Example #1
 def test_qos_exceeds_16bit(self):
     with patch('kombu.common.logger') as logger:
         callback = Mock()
         qos = QoS(callback, 10)
         qos.prev = 100
         qos.set(2 ** 32)
         self.assertTrue(logger.warn.called)
         callback.assert_called_with(prefetch_count=0)
Example #2
 def test_qos_exceeds_16bit(self):
     with patch('kombu.common.logger') as logger:
         callback = Mock()
         qos = QoS(callback, 10)
         qos.prev = 100
         # cannot use 2 ** 32 because of a bug on OSX Py2.5:
         # https://jira.mongodb.org/browse/PYTHON-389
         qos.set(4294967296)
         self.assertTrue(logger.warn.called)
         callback.assert_called_with(prefetch_count=0)
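Note: both variants above pin down the same guard. The AMQP prefetch_count field is a 16-bit short, so any value above 0xFFFF cannot be sent to the broker; kombu's QoS instead warns and disables the prefetch limit by sending 0. A minimal sketch of that clamping, assuming the kombu.common semantics these tests exercise (the helper name set_prefetch is hypothetical):

    import logging

    logger = logging.getLogger(__name__)
    PREFETCH_COUNT_MAX = 0xFFFF  # largest value of a 16-bit AMQP short

    def set_prefetch(callback, pcount):
        # Values that do not fit in 16 bits disable prefetch limits
        # entirely: warn, then send prefetch_count=0 to the broker.
        new_value = pcount
        if pcount > PREFETCH_COUNT_MAX:
            logger.warning('QoS: Disabled: prefetch_count exceeds %r',
                           PREFETCH_COUNT_MAX)
            new_value = 0
        callback(prefetch_count=new_value)
        return pcount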
Example #3
 def test_consumer_decrement_eventually(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.decrement_eventually()
     self.assertEqual(qos.value, 9)
     qos.value = 0
     qos.decrement_eventually()
     self.assertEqual(qos.value, 0)
Example #4
File: test_common.py Project: celery/kombu
 def test_consumer_decrement_eventually(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.decrement_eventually()
     assert qos.value == 9
     qos.value = 0
     qos.decrement_eventually()
     assert qos.value == 0
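Both versions of this test assert the same invariant: decrement_eventually changes only the local counter, and a value of 0 (meaning no prefetch limit) is never decremented. A rough sketch of that rule, with the hypothetical class name EventualQoS:

    class EventualQoS:
        """Sketch of the eventual-decrement rule asserted above."""

        def __init__(self, value):
            self.value = value

        def decrement_eventually(self, n=1):
            # 0 means "no prefetch limit" and is left untouched;
            # nothing is sent to the broker until update() runs.
            if self.value:
                self.value -= n
            return self.value

    assert EventualQoS(10).decrement_eventually() == 9
    assert EventualQoS(0).decrement_eventually() == 0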
Example #5
    def test_loop_ignores_socket_timeout(self):
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer, 10)
        l.loop(*l.loop_args())
Example #6
    def test_loop_ignores_socket_timeout(self):
        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer.qos, 10)
        l.loop(*l.loop_args())
Example #7
    def test_with_autoscaler_file_descriptor_safety(self):
        # Given: a test celery worker instance with auto scaling
        worker = self.create_worker(
            autoscale=[10, 5],
            use_eventloop=True,
            timer_cls='celery.utils.timer2.Timer',
            threads=False,
        )
        # Given: This test requires a QoS defined on the worker consumer
        worker.consumer.qos = qos = QoS(lambda prefetch_count: prefetch_count,
                                        2)
        qos.update()

        # Given: We have started the worker pool
        worker.pool.start()

        # Then: the worker pool is the same as the autoscaler pool
        auto_scaler = worker.autoscaler
        assert worker.pool == auto_scaler.pool

        # Given: Utilize kombu to get the global hub state
        hub = get_event_loop()
        # Given: Initial call the Async Pool to register events works fine
        worker.pool.register_with_event_loop(hub)

        # Create some mock queue message and read from them
        _keep = [Mock(name=f'req{i}') for i in range(20)]
        [state.task_reserved(m) for m in _keep]
        auto_scaler.body()

        # Simulate a file descriptor from the list is closed by the OS
        # auto_scaler.force_scale_down(5)
        # This actually works -- it releases the semaphore properly
        # Same with calling .terminate() on the process directly
        for fd, proc in worker.pool._pool._fileno_to_outq.items():
            # however opening this fd as a file and closing it will do it
            queue_worker_socket = open(str(fd), "w")
            queue_worker_socket.close()
            break  # Only need to do this once

        # When: Calling again to register with event loop ...
        worker.pool.register_with_event_loop(hub)

        # Then: test did not raise "OSError: [Errno 9] Bad file descriptor!"

        # Finally:  Clean up so the threads before/after fixture passes
        worker.terminate()
        worker.pool.terminate()
Example #8
    def test_loop_when_socket_error(self):
        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        conn = c.connection = Connection()
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            c.loop(*c.loop_args())

        c.blueprint.state = CLOSE
        c.connection = conn
        c.loop(*c.loop_args())
Example #9
    def test_loop_when_socket_error(self):
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = Consumer(self.ready_queue, timer=self.timer)
        l.namespace.state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            l.loop(*l.loop_args())

        l.namespace.state = CLOSE
        l.connection = c
        l.loop(*l.loop_args())
Example #10
    def test_loop_when_socket_error(self):
        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            l.loop(*l.loop_args())

        l.blueprint.state = CLOSE
        l.connection = c
        l.loop(*l.loop_args())
Example #11
File: consumer.py Project: quoter/celery
    def reset_connection(self):
        """Re-establish the broker connection and set up consumers,
        heartbeat and the event dispatcher."""
        debug('Re-establishing connection to the broker...')
        self.stop_consumers(join=False)

        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        self.ready_queue.clear()
        self.timer.clear()

        # Re-establish the broker connection and setup the task consumer.
        self.connection = self._open_connection()
        info('consumer: Connected to %s.', self.connection.as_uri())
        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
                                    on_decode_error=self.on_decode_error)
        # QoS: Reset prefetch window.
        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
        self.qos.update()

        # Setup the process mailbox.
        self.reset_pidbox_node()

        # Flush events sent while connection was down.
        prev_event_dispatcher = self.event_dispatcher
        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
                                                hostname=self.hostname,
                                                enabled=self.send_events)
        if prev_event_dispatcher:
            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
            self.event_dispatcher.flush()

        # Restart heartbeat thread.
        self.restart_heartbeat()

        # reload all task's execution strategies.
        self.update_strategies()

        # We're back!
        self._state = RUN
Example #12
    def start(self, c):
        c.update_strategies()

        # - RabbitMQ 3.3 completely redefines how basic_qos works...
        # This will detect if the new qos semantics are in effect,
        # and if so make sure the 'apply_global' flag is set on qos updates.
        qos_global = not c.connection.qos_semantics_matches_spec

        # set initial prefetch count
        c.connection.default_channel.basic_qos(
            0, c.initial_prefetch_count, qos_global,
        )

        c.task_consumer = c.app.amqp.TaskConsumer(
            c.connection, on_decode_error=c.on_decode_error,
        )

        def set_prefetch_count(prefetch_count):
            return c.task_consumer.qos(
                prefetch_count=prefetch_count,
                apply_global=qos_global,
            )
        c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
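The set_prefetch_count closure is what keeps QoS transport-agnostic here: QoS only ever invokes callback(prefetch_count=n), and the closure decides whether to add apply_global for RabbitMQ 3.3's redefined basic_qos. A hedged, Mock-based check of that wiring in the style of the tests above (assumes QoS is importable from kombu.common, as in the kombu tests):

    from unittest.mock import Mock

    from kombu.common import QoS  # assumed import, as in the kombu tests

    task_consumer = Mock()
    qos_global = True  # as would be detected for RabbitMQ >= 3.3

    def set_prefetch_count(prefetch_count):
        return task_consumer.qos(prefetch_count=prefetch_count,
                                 apply_global=qos_global)

    qos = QoS(set_prefetch_count, 10)
    qos.update()  # flush the initial prefetch count through the closure
    task_consumer.qos.assert_called_with(prefetch_count=10,
                                         apply_global=True)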
Example #13
    def test_loop(self):
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = Consumer(self.buffer.put, timer=self.timer)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)

        l.loop(*l.loop_args())
        l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)
Example #14
    def test_loop(self):
        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.connection = Connection()
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)

        c.loop(*c.loop_args())
        c.loop(*c.loop_args())
        self.assertTrue(c.task_consumer.consume.call_count)
        c.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(c.qos.value, 10)
        c.qos.decrement_eventually()
        self.assertEqual(c.qos.value, 9)
        c.qos.update()
        self.assertEqual(c.qos.value, 9)
        c.task_consumer.qos.assert_called_with(prefetch_count=9)
Example #15
    def test_receieve_message_eta_isoformat(self):
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(), task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        l.receive_message(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        items = [entry[2] for entry in self.timer.queue]
        found = False
        for item in items:
            if item.args[0].name == foo_task.name:
                found = True
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)
        l.timer.stop()
Example #16
    def test_consumer_increment_decrement(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.update()
        self.assertEqual(qos.value, 10)
        mconsumer.qos.assert_called_with(prefetch_count=10)
        qos.decrement_eventually()
        qos.update()
        self.assertEqual(qos.value, 9)
        mconsumer.qos.assert_called_with(prefetch_count=9)
        qos.decrement_eventually()
        self.assertEqual(qos.value, 8)
        mconsumer.qos.assert_called_with(prefetch_count=9)
        self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)

        # Does not decrement 0 value
        qos.value = 0
        qos.decrement_eventually()
        self.assertEqual(qos.value, 0)
        qos.increment_eventually()
        self.assertEqual(qos.value, 0)
Example #17
 def test_exceeds_short(self):
     qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
     qos.update()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
     qos.increment_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
     qos.increment_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
     qos.decrement_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
     qos.decrement_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
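PREFETCH_COUNT_MAX here is the 16-bit ceiling (0xFFFF, i.e. 65535). The local counter may drift one past it because increment_eventually touches only qos.value; the clamp to prefetch_count=0 from Example #1 applies only once the value is flushed. For instance (imports assumed from kombu.common):

    from unittest.mock import Mock

    from kombu.common import QoS, PREFETCH_COUNT_MAX  # assumed imports

    callback = Mock()
    qos = QoS(callback, PREFETCH_COUNT_MAX)
    qos.increment_eventually()      # local value drifts past the ceiling
    assert qos.value == PREFETCH_COUNT_MAX + 1
    qos.update()                    # clamped only when actually flushed
    callback.assert_called_with(prefetch_count=0)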
Example #18
 def __init__(self, value):
     self.value = value
     QoS.__init__(self, None, value)
Example #19
 def test_set(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.set(12)
     assert qos.prev == 12
     qos.set(qos.prev)
Example #20
    def test_consumer_increment_decrement(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.update()
        assert qos.value == 10
        mconsumer.qos.assert_called_with(prefetch_count=10)
        qos.decrement_eventually()
        qos.update()
        assert qos.value == 9
        mconsumer.qos.assert_called_with(prefetch_count=9)
        qos.decrement_eventually()
        assert qos.value == 8
        mconsumer.qos.assert_called_with(prefetch_count=9)
        assert {'prefetch_count': 9} in mconsumer.qos.call_args

        # Does not decrement 0 value
        qos.value = 0
        qos.decrement_eventually()
        assert qos.value == 0
        qos.increment_eventually()
        assert qos.value == 0
Example #21
File: test_common.py Project: celery/kombu
 def test_set(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.set(12)
     assert qos.prev == 12
     qos.set(qos.prev)
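set() also records the requested value in prev, and a repeated set() with the same value is a no-op; that is why the trailing set(qos.prev) call above needs no further assertion. A sketch of that bookkeeping (the class name SketchQoS is hypothetical):

    from unittest.mock import Mock

    class SketchQoS:
        """Sketch of the prev bookkeeping asserted above."""
        prev = None

        def __init__(self, callback):
            self.callback = callback

        def set(self, pcount):
            if pcount != self.prev:  # repeated values skip the round-trip
                self.callback(prefetch_count=pcount)
                self.prev = pcount
            return pcount

    mconsumer = Mock()
    qos = SketchQoS(mconsumer.qos)
    qos.set(12)
    qos.set(12)  # no second broker call
    assert mconsumer.qos.call_count == 1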
Example #22
 def __init__(self, value):
     self.value = value
     QoS.__init__(self, None, value)
Example #23
 def test_set(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.set(12)
     self.assertEqual(qos.prev, 12)
     qos.set(qos.prev)
Example #24
    def test_consumer_increment_decrement(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.update()
        self.assertEqual(qos.value, 10)
        mconsumer.qos.assert_called_with(prefetch_count=10)
        qos.decrement_eventually()
        qos.update()
        self.assertEqual(qos.value, 9)
        mconsumer.qos.assert_called_with(prefetch_count=9)
        qos.decrement_eventually()
        self.assertEqual(qos.value, 8)
        mconsumer.qos.assert_called_with(prefetch_count=9)
        self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)

        # Does not decrement 0 value
        qos.value = 0
        qos.decrement_eventually()
        self.assertEqual(qos.value, 0)
        qos.increment_eventually()
        self.assertEqual(qos.value, 0)
Example #25
File: consumer.py Project: wiennat/celery
class Consumer(object):
    """Listen for messages received from the broker and
    move them to the ready queue for task processing.

    :param ready_queue: See :attr:`ready_queue`.
    :param timer: See :attr:`timer`.

    """

    #: The queue that holds tasks ready for immediate processing.
    ready_queue = None

    #: Enable/disable events.
    send_events = False

    #: Optional callback to be called when the connection is established.
    #: Will only be called once, even if the connection is lost and
    #: re-established.
    init_callback = None

    #: The current hostname.  Defaults to the system hostname.
    hostname = None

    #: Initial QoS prefetch count for the task channel.
    initial_prefetch_count = 0

    #: A :class:`celery.events.EventDispatcher` for sending events.
    event_dispatcher = None

    #: The thread that sends event heartbeats at regular intervals.
    #: The heartbeats are used by monitors to detect that a worker
    #: went offline/disappeared.
    heart = None

    #: The broker connection.
    connection = None

    #: The consumer used to consume task messages.
    task_consumer = None

    #: The consumer used to consume broadcast commands.
    broadcast_consumer = None

    #: The process mailbox (kombu pidbox node).
    pidbox_node = None
    _pidbox_node_shutdown = None   # used for greenlets
    _pidbox_node_stopped = None    # used for greenlets

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    # Consumer state, can be RUN or CLOSE.
    _state = None

    def __init__(self, ready_queue,
            init_callback=noop, send_events=False, hostname=None,
            initial_prefetch_count=2, pool=None, app=None,
            timer=None, controller=None, hub=None, amqheartbeat=None,
            **kwargs):
        self.app = app_or_default(app)
        self.connection = None
        self.task_consumer = None
        self.controller = controller
        self.broadcast_consumer = None
        self.ready_queue = ready_queue
        self.send_events = send_events
        self.init_callback = init_callback
        self.hostname = hostname or socket.gethostname()
        self.initial_prefetch_count = initial_prefetch_count
        self.event_dispatcher = None
        self.heart = None
        self.pool = pool
        self.timer = timer or timer2.default_timer
        pidbox_state = AttributeDict(app=self.app,
                                     hostname=self.hostname,
                                     listener=self,     # pre 2.2
                                     consumer=self)
        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
                                                         state=pidbox_state,
                                                         handlers=Panel.data)
        conninfo = self.app.connection()
        self.connection_errors = conninfo.connection_errors
        self.channel_errors = conninfo.channel_errors

        self._does_info = logger.isEnabledFor(logging.INFO)
        self.strategies = {}
        if hub:
            hub.on_init.append(self.on_poll_init)
        self.hub = hub
        self._quick_put = self.ready_queue.put
        self.amqheartbeat = amqheartbeat
        if self.amqheartbeat is None:
            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
        if not hub:
            self.amqheartbeat = 0

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.BROKER_CONNECTION_TIMEOUT = None

    def update_strategies(self):
        S = self.strategies
        app = self.app
        loader = app.loader
        hostname = self.hostname
        for name, task in self.app.tasks.iteritems():
            S[name] = task.start_strategy(app, self)
            task.__trace__ = build_tracer(name, task, loader, hostname)

    def start(self):
        """Start the consumer.

        Automatically survives intermittent connection failure,
        and will retry establishing the connection and restart
        consuming messages.

        """

        self.init_callback(self)

        while self._state != CLOSE:
            self.maybe_shutdown()
            try:
                self.reset_connection()
                self.consume_messages()
            except self.connection_errors + self.channel_errors:
                error(RETRY_CONNECTION, exc_info=True)

    def on_poll_init(self, hub):
        hub.update_readers(self.connection.eventmap)
        self.connection.transport.on_poll_init(hub.poller)

    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
            hbrate=AMQHEARTBEAT_RATE):
        """Consume messages forever (or until an exception is raised)."""

        with self.hub as hub:
            qos = self.qos
            update_qos = qos.update
            update_readers = hub.update_readers
            readers, writers = hub.readers, hub.writers
            poll = hub.poller.poll
            fire_timers = hub.fire_timers
            scheduled = hub.timer._queue
            connection = self.connection
            hb = self.amqheartbeat
            hbtick = connection.heartbeat_check
            on_poll_start = connection.transport.on_poll_start
            on_poll_empty = connection.transport.on_poll_empty
            strategies = self.strategies
            drain_nowait = connection.drain_nowait
            on_task_callbacks = hub.on_task
            keep_draining = connection.transport.nb_keep_draining

            if hb and connection.supports_heartbeats:
                hub.timer.apply_interval(
                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))

            def on_task_received(body, message):
                if on_task_callbacks:
                    [callback() for callback in on_task_callbacks]
                try:
                    name = body['task']
                except (KeyError, TypeError):
                    return self.handle_unknown_message(body, message)
                try:
                    strategies[name](message, body, message.ack_log_error)
                except KeyError as exc:
                    self.handle_unknown_task(body, message, exc)
                except InvalidTaskError as exc:
                    self.handle_invalid_task(body, message, exc)
                #fire_timers()

            self.task_consumer.callbacks = [on_task_received]
            self.task_consumer.consume()

            debug('Ready to accept tasks!')

            while self._state != CLOSE and self.connection:
                # shutdown if signal handlers told us to.
                if state.should_stop:
                    raise SystemExit()
                elif state.should_terminate:
                    raise SystemTerminate()

                # fire any ready timers, this also returns
                # the number of seconds until we need to fire timers again.
                poll_timeout = fire_timers() if scheduled else 1

                # We only update QoS when there are no more messages to read.
                # This groups together qos calls, and makes sure that remote
                # control commands will be prioritized over task messages.
                if qos.prev != qos.value:
                    update_qos()

                update_readers(on_poll_start())
                if readers or writers:
                    connection.more_to_read = True
                    while connection.more_to_read:
                        try:
                            events = poll(poll_timeout)
                        except ValueError:  # Issue 882
                            return
                        if not events:
                            on_poll_empty()
                        for fileno, event in events or ():
                            try:
                                if event & READ:
                                    readers[fileno](fileno, event)
                                if event & WRITE:
                                    writers[fileno](fileno, event)
                                if event & ERR:
                                    for handlermap in readers, writers:
                                        try:
                                            handlermap[fileno](fileno, event)
                                        except KeyError:
                                            pass
                            except (KeyError, Empty):
                                continue
                            except socket.error:
                                if self._state != CLOSE:  # pragma: no cover
                                    raise
                        if keep_draining:
                            drain_nowait()
                            poll_timeout = 0
                        else:
                            connection.more_to_read = False
                else:
                    # no sockets yet, startup is probably not done.
                    sleep(min(poll_timeout, 0.1))

    def on_task(self, task, task_reserved=task_reserved):
        """Handle received task.

        If the task has an `eta` we enter it into the ETA schedule,
        otherwise we move it to the ready queue for immediate processing.

        """
        if task.revoked():
            return

        if self._does_info:
            info('Got task from broker: %s', task)

        if self.event_dispatcher.enabled:
            self.event_dispatcher.send('task-received', uuid=task.id,
                    name=task.name, args=safe_repr(task.args),
                    kwargs=safe_repr(task.kwargs),
                    retries=task.request_dict.get('retries', 0),
                    eta=task.eta and task.eta.isoformat(),
                    expires=task.expires and task.expires.isoformat())

        if task.eta:
            eta = timezone.to_system(task.eta) if task.utc else task.eta
            try:
                eta = timer2.to_timestamp(eta)
            except OverflowError as exc:
                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                      task.eta, exc, task.info(safe=True), exc_info=True)
                task.acknowledge()
            else:
                self.qos.increment_eventually()
                self.timer.apply_at(
                    eta, self.apply_eta_task, (task, ), priority=6,
                )
        else:
            task_reserved(task)
            self._quick_put(task)

    def on_control(self, body, message):
        """Process remote control command message."""
        try:
            self.pidbox_node.handle_message(body, message)
        except KeyError as exc:
            error('No such control command: %s', exc)
        except Exception as exc:
            error('Control command error: %r', exc, exc_info=True)
            self.reset_pidbox_node()

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an
        ETA/countdown."""
        task_reserved(task)
        self._quick_put(task)
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info))

    def handle_unknown_message(self, body, message):
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)

    def handle_unknown_task(self, body, message, exc):
        error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def handle_invalid_task(self, body, message, exc):
        error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def receive_message(self, body, message):
        """Handles incoming messages.

        :param body: The message body.
        :param message: The kombu message object.

        """
        try:
            name = body['task']
        except (KeyError, TypeError):
            return self.handle_unknown_message(body, message)

        try:
            self.strategies[name](message, body, message.ack_log_error)
        except KeyError as exc:
            self.handle_unknown_task(body, message, exc)
        except InvalidTaskError as exc:
            self.handle_invalid_task(body, message, exc)

    def maybe_conn_error(self, fun):
        """Applies function but ignores any connection or channel
        errors raised."""
        try:
            fun()
        except (AttributeError, ) + \
                self.connection_errors + \
                self.channel_errors:
            pass

    def close_connection(self):
        """Closes the current broker connection and all open channels."""

        # We must set self.connection to None here, so
        # that the green pidbox thread exits.
        connection, self.connection = self.connection, None

        if self.task_consumer:
            debug('Closing consumer channel...')
            self.task_consumer = \
                    self.maybe_conn_error(self.task_consumer.close)

        self.stop_pidbox_node()

        if connection:
            debug('Closing broker connection...')
            self.maybe_conn_error(connection.close)

    def stop_consumers(self, close_connection=True, join=True):
        """Stop consuming tasks and broadcast commands, also stops
        the heartbeat thread and event dispatcher.

        :keyword close_connection: Set to False to skip closing the broker
                                    connection.

        """
        if not self._state == RUN:
            return

        if self.heart:
            # Stop the heartbeat thread if it's running.
            debug('Heart: Going into cardiac arrest...')
            self.heart = self.heart.stop()

        debug('Cancelling task consumer...')
        if join and self.task_consumer:
            self.maybe_conn_error(self.task_consumer.cancel)

        if self.event_dispatcher:
            debug('Shutting down event dispatcher...')
            self.event_dispatcher = \
                    self.maybe_conn_error(self.event_dispatcher.close)

        debug('Cancelling broadcast consumer...')
        if join and self.broadcast_consumer:
            self.maybe_conn_error(self.broadcast_consumer.cancel)

        if close_connection:
            self.close_connection()

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding
        a message received.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        :param message: The message with errors.
        :param exc: The original exception instance.

        """
        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
             exc, message.content_type, message.content_encoding,
             dump_body(message, message.body))
        message.ack()

    def reset_pidbox_node(self):
        """Sets up the process mailbox."""
        self.stop_pidbox_node()
        # close previously opened channel if any.
        if self.pidbox_node.channel:
            try:
                self.pidbox_node.channel.close()
            except self.connection_errors + self.channel_errors:
                pass

        if self.pool is not None and self.pool.is_green:
            return self.pool.spawn_n(self._green_pidbox_node)
        self.pidbox_node.channel = self.connection.channel()
        self.broadcast_consumer = self.pidbox_node.listen(
                                        callback=self.on_control)

    def stop_pidbox_node(self):
        if self._pidbox_node_stopped:
            self._pidbox_node_shutdown.set()
            debug('Waiting for broadcast thread to shutdown...')
            self._pidbox_node_stopped.wait()
            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
        elif self.broadcast_consumer:
            debug('Closing broadcast channel...')
            self.broadcast_consumer = \
                self.maybe_conn_error(self.broadcast_consumer.channel.close)

    def _green_pidbox_node(self):
        """Sets up the process mailbox when running in a greenlet
        environment."""
        # THIS CODE IS TERRIBLE
        # Luckily work has already started rewriting the Consumer for 4.0.
        self._pidbox_node_shutdown = threading.Event()
        self._pidbox_node_stopped = threading.Event()
        try:
            with self._open_connection() as conn:
                info('pidbox: Connected to %s.', conn.as_uri())
                self.pidbox_node.channel = conn.default_channel
                self.broadcast_consumer = self.pidbox_node.listen(
                                            callback=self.on_control)
                with self.broadcast_consumer:
                    while not self._pidbox_node_shutdown.isSet():
                        try:
                            conn.drain_events(timeout=1.0)
                        except socket.timeout:
                            pass
        finally:
            self._pidbox_node_stopped.set()

    def reset_connection(self):
        """Re-establish the broker connection and set up consumers,
        heartbeat and the event dispatcher."""
        debug('Re-establishing connection to the broker...')
        self.stop_consumers(join=False)

        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        self.ready_queue.clear()
        self.timer.clear()

        # Re-establish the broker connection and setup the task consumer.
        self.connection = self._open_connection()
        info('consumer: Connected to %s.', self.connection.as_uri())
        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
                                    on_decode_error=self.on_decode_error)
        # QoS: Reset prefetch window.
        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
        self.qos.update()

        # Setup the process mailbox.
        self.reset_pidbox_node()

        # Flush events sent while connection was down.
        prev_event_dispatcher = self.event_dispatcher
        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
                                                hostname=self.hostname,
                                                enabled=self.send_events)
        if prev_event_dispatcher:
            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
            self.event_dispatcher.flush()

        # Restart heartbeat thread.
        self.restart_heartbeat()

        # reload all task's execution strategies.
        self.update_strategies()

        # We're back!
        self._state = RUN

    def restart_heartbeat(self):
        """Restart the heartbeat thread.

        This thread sends heartbeat events at intervals so monitors
        can tell if the worker is off-line/missing.

        """
        self.heart = Heart(self.timer, self.event_dispatcher)
        self.heart.start()

    def _open_connection(self):
        """Establish the broker connection.

        Will retry establishing the connection if the
        :setting:`BROKER_CONNECTION_RETRY` setting is enabled.

        """
        conn = self.app.connection(heartbeat=self.amqheartbeat)

        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            error(CONNECTION_ERROR, conn.as_uri(), exc,
                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))

        # remember that the connection is lazy; it won't establish
        # until it's needed.
        if not self.app.conf.BROKER_CONNECTION_RETRY:
            # retry disabled, just call connect directly.
            conn.connect()
            return conn

        return conn.ensure_connection(_error_handler,
                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                    callback=self.maybe_shutdown)

    def stop(self):
        """Stop consuming.

        Does not close the broker connection, so be sure to call
        :meth:`close_connection` when you are finished with it.

        """
        # Notifies other threads that this instance can't be used
        # anymore.
        self.close()
        debug('Stopping consumers...')
        self.stop_consumers(close_connection=False, join=True)

    def close(self):
        self._state = CLOSE

    def maybe_shutdown(self):
        if state.should_stop:
            raise SystemExit()
        elif state.should_terminate:
            raise SystemTerminate()

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
            routing_key=None, **options):
        cset = self.task_consumer
        try:
            q = self.app.amqp.queues[queue]
        except KeyError:
            exchange = queue if exchange is None else exchange
            exchange_type = 'direct' if exchange_type is None \
                                     else exchange_type
            q = self.app.amqp.queues.select_add(queue,
                    exchange=exchange,
                    exchange_type=exchange_type,
                    routing_key=routing_key, **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            logger.info('Started consuming from %r', queue)

    def cancel_task_queue(self, queue):
        self.app.amqp.queues.select_remove(queue)
        self.task_consumer.cancel_by_queue(queue)

    @property
    def info(self):
        """Returns information about this consumer instance
        as a dict.

        This is also the consumer related info returned by
        ``celeryctl stats``.

        """
        conninfo = {}
        if self.connection:
            conninfo = self.connection.info()
            conninfo.pop('password', None)  # don't send password.
        return {'broker': conninfo, 'prefetch_count': self.qos.value}
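The consume_messages loop above is also where eventual QoS changes are flushed: on_task and apply_eta_task only move qos.value via increment_eventually/decrement_eventually, and the loop issues at most one basic_qos per iteration, and only when the value has actually drifted. Condensed:

    # Condensed from the loop body above: group qos calls together and
    # flush only when the local value has drifted from what was sent.
    if qos.prev != qos.value:
        qos.update()  # invokes callback(prefetch_count=qos.value)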
Example #26
File: consumer.py Project: mindpool/celery
 def start(self, c):
     c.task_consumer = c.app.amqp.TaskConsumer(
         c.connection, on_decode_error=c.on_decode_error,
     )
     c.qos = QoS(c.task_consumer.qos, self.initial_prefetch_count)
     c.qos.update()  # set initial prefetch count
Example #27
File: test_common.py Project: celery/kombu
 def test_exceeds_short(self):
     qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
     qos.update()
     assert qos.value == PREFETCH_COUNT_MAX - 1
     qos.increment_eventually()
     assert qos.value == PREFETCH_COUNT_MAX
     qos.increment_eventually()
     assert qos.value == PREFETCH_COUNT_MAX + 1
     qos.decrement_eventually()
     assert qos.value == PREFETCH_COUNT_MAX
     qos.decrement_eventually()
     assert qos.value == PREFETCH_COUNT_MAX - 1
Example #28
 def test_set(self):
     mconsumer = Mock()
     qos = QoS(mconsumer.qos, 10)
     qos.set(12)
     self.assertEqual(qos.prev, 12)
     qos.set(qos.prev)
Example #29
 def test_exceeds_short(self):
     qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
     qos.update()
     assert qos.value == PREFETCH_COUNT_MAX - 1
     qos.increment_eventually()
     assert qos.value == PREFETCH_COUNT_MAX
     qos.increment_eventually()
     assert qos.value == PREFETCH_COUNT_MAX + 1
     qos.decrement_eventually()
     assert qos.value == PREFETCH_COUNT_MAX
     qos.decrement_eventually()
     assert qos.value == PREFETCH_COUNT_MAX - 1
Example #30
File: test_common.py Project: celery/kombu
    def test_consumer_increment_decrement(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.update()
        assert qos.value == 10
        mconsumer.qos.assert_called_with(prefetch_count=10)
        qos.decrement_eventually()
        qos.update()
        assert qos.value == 9
        mconsumer.qos.assert_called_with(prefetch_count=9)
        qos.decrement_eventually()
        assert qos.value == 8
        mconsumer.qos.assert_called_with(prefetch_count=9)
        assert {'prefetch_count': 9} in mconsumer.qos.call_args

        # Does not decrement 0 value
        qos.value = 0
        qos.decrement_eventually()
        assert qos.value == 0
        qos.increment_eventually()
        assert qos.value == 0
Example #31
 def test_exceeds_short(self):
     qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
     qos.update()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
     qos.increment_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
     qos.increment_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
     qos.decrement_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
     qos.decrement_eventually()
     self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
Example #32
File: consumer.py Project: quoter/celery
class Consumer(object):
    """Listen for messages received from the broker and
    move them to the ready queue for task processing.

    :param ready_queue: See :attr:`ready_queue`.
    :param timer: See :attr:`timer`.

    """

    #: The queue that holds tasks ready for immediate processing.
    ready_queue = None

    #: Enable/disable events.
    send_events = False

    #: Optional callback to be called when the connection is established.
    #: Will only be called once, even if the connection is lost and
    #: re-established.
    init_callback = None

    #: The current hostname.  Defaults to the system hostname.
    hostname = None

    #: Initial QoS prefetch count for the task channel.
    initial_prefetch_count = 0

    #: A :class:`celery.events.EventDispatcher` for sending events.
    event_dispatcher = None

    #: The thread that sends event heartbeats at regular intervals.
    #: The heartbeats are used by monitors to detect that a worker
    #: went offline/disappeared.
    heart = None

    #: The broker connection.
    connection = None

    #: The consumer used to consume task messages.
    task_consumer = None

    #: The consumer used to consume broadcast commands.
    broadcast_consumer = None

    #: The process mailbox (kombu pidbox node).
    pidbox_node = None
    _pidbox_node_shutdown = None   # used for greenlets
    _pidbox_node_stopped = None    # used for greenlets

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    # Consumer state, can be RUN or CLOSE.
    _state = None

    def __init__(self, ready_queue,
            init_callback=noop, send_events=False, hostname=None,
            initial_prefetch_count=2, pool=None, app=None,
            timer=None, controller=None, hub=None, amqheartbeat=None,
            **kwargs):
        self.app = app_or_default(app)
        self.connection = None
        self.task_consumer = None
        self.controller = controller
        self.broadcast_consumer = None
        self.ready_queue = ready_queue
        self.send_events = send_events
        self.init_callback = init_callback
        self.hostname = hostname or socket.gethostname()
        self.initial_prefetch_count = initial_prefetch_count
        self.event_dispatcher = None
        self.heart = None
        self.pool = pool
        self.timer = timer or timer2.default_timer
        pidbox_state = AttributeDict(app=self.app,
                                     hostname=self.hostname,
                                     listener=self,     # pre 2.2
                                     consumer=self)
        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
                                                         state=pidbox_state,
                                                         handlers=Panel.data)
        conninfo = self.app.connection()
        self.connection_errors = conninfo.connection_errors
        self.channel_errors = conninfo.channel_errors

        self._does_info = logger.isEnabledFor(logging.INFO)
        self.strategies = {}
        if hub:
            hub.on_init.append(self.on_poll_init)
        self.hub = hub
        self._quick_put = self.ready_queue.put
        self.amqheartbeat = amqheartbeat
        if self.amqheartbeat is None:
            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
        if not hub:
            self.amqheartbeat = 0

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.BROKER_CONNECTION_TIMEOUT = None

    def update_strategies(self):
        S = self.strategies
        app = self.app
        loader = app.loader
        hostname = self.hostname
        for name, task in self.app.tasks.iteritems():
            S[name] = task.start_strategy(app, self)
            task.__trace__ = build_tracer(name, task, loader, hostname)

    def start(self):
        """Start the consumer.

        Automatically survives intermittent connection failure,
        and will retry establishing the connection and restart
        consuming messages.

        """

        self.init_callback(self)

        while self._state != CLOSE:
            self.maybe_shutdown()
            try:
                self.reset_connection()
                self.consume_messages()
            except self.connection_errors + self.channel_errors:
                error(RETRY_CONNECTION, exc_info=True)

    def on_poll_init(self, hub):
        hub.update_readers(self.connection.eventmap)
        self.connection.transport.on_poll_init(hub.poller)

    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
            hbrate=AMQHEARTBEAT_RATE):
        """Consume messages forever (or until an exception is raised)."""

        with self.hub as hub:
            qos = self.qos
            update_qos = qos.update
            update_readers = hub.update_readers
            readers, writers = hub.readers, hub.writers
            poll = hub.poller.poll
            fire_timers = hub.fire_timers
            scheduled = hub.timer._queue
            connection = self.connection
            hb = self.amqheartbeat
            hbtick = connection.heartbeat_check
            on_poll_start = connection.transport.on_poll_start
            on_poll_empty = connection.transport.on_poll_empty
            strategies = self.strategies
            drain_nowait = connection.drain_nowait
            on_task_callbacks = hub.on_task
            keep_draining = connection.transport.nb_keep_draining

            if hb and connection.supports_heartbeats:
                hub.timer.apply_interval(
                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))

            def on_task_received(body, message):
                if on_task_callbacks:
                    [callback() for callback in on_task_callbacks]
                try:
                    name = body['task']
                except (KeyError, TypeError):
                    return self.handle_unknown_message(body, message)
                try:
                    strategies[name](message, body, message.ack_log_error)
                except KeyError as exc:
                    self.handle_unknown_task(body, message, exc)
                except InvalidTaskError as exc:
                    self.handle_invalid_task(body, message, exc)
                #fire_timers()

            self.task_consumer.callbacks = [on_task_received]
            self.task_consumer.consume()

            debug('Ready to accept tasks!')

            while self._state != CLOSE and self.connection:
                # shutdown if signal handlers told us to.
                if state.should_stop:
                    raise SystemExit()
                elif state.should_terminate:
                    raise SystemTerminate()

                # fire any ready timers, this also returns
                # the number of seconds until we need to fire timers again.
                poll_timeout = fire_timers() if scheduled else 1

                # We only update QoS when there are no more messages to read.
                # This groups together qos calls, and makes sure that remote
                # control commands will be prioritized over task messages.
                if qos.prev != qos.value:
                    update_qos()

                update_readers(on_poll_start())
                if readers or writers:
                    connection.more_to_read = True
                    while connection.more_to_read:
                        try:
                            events = poll(poll_timeout)
                        except ValueError:  # Issue 882
                            return
                        if not events:
                            on_poll_empty()
                        for fileno, event in events or ():
                            try:
                                if event & READ:
                                    readers[fileno](fileno, event)
                                if event & WRITE:
                                    writers[fileno](fileno, event)
                                if event & ERR:
                                    for handlermap in readers, writers:
                                        try:
                                            handlermap[fileno](fileno, event)
                                        except KeyError:
                                            pass
                            except (KeyError, Empty):
                                continue
                            except socket.error:
                                if self._state != CLOSE:  # pragma: no cover
                                    raise
                        if keep_draining:
                            drain_nowait()
                            poll_timeout = 0
                        else:
                            connection.more_to_read = False
                else:
                    # no sockets yet, startup is probably not done.
                    sleep(min(poll_timeout, 0.1))

    def on_task(self, task, task_reserved=task_reserved):
        """Handle received task.

        If the task has an `eta` we enter it into the ETA schedule,
        otherwise we move it to the ready queue for immediate processing.

        """
        if task.revoked():
            return

        if self._does_info:
            info('Got task from broker: %s', task)

        if self.event_dispatcher.enabled:
            self.event_dispatcher.send('task-received', uuid=task.id,
                    name=task.name, args=safe_repr(task.args),
                    kwargs=safe_repr(task.kwargs),
                    retries=task.request_dict.get('retries', 0),
                    eta=task.eta and task.eta.isoformat(),
                    expires=task.expires and task.expires.isoformat())

        if task.eta:
            try:
                eta = timer2.to_timestamp(task.eta)
            except OverflowError as exc:
                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                      task.eta, exc, task.info(safe=True), exc_info=True)
                task.acknowledge()
            else:
                self.qos.increment_eventually()
                self.timer.apply_at(eta, self.apply_eta_task, (task, ),
                                    priority=6)
        else:
            task_reserved(task)
            self._quick_put(task)

    def on_control(self, body, message):
        """Process remote control command message."""
        try:
            self.pidbox_node.handle_message(body, message)
        except KeyError as exc:
            error('No such control command: %s', exc)
        except Exception as exc:
            error('Control command error: %r', exc, exc_info=True)
            self.reset_pidbox_node()

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an
        ETA/countdown."""
        task_reserved(task)
        self._quick_put(task)
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info))

    def handle_unknown_message(self, body, message):
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)

    def handle_unknown_task(self, body, message, exc):
        error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def handle_invalid_task(self, body, message, exc):
        error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def receive_message(self, body, message):
        """Handles incoming messages.

        :param body: The message body.
        :param message: The kombu message object.

        """
        try:
            name = body['task']
        except (KeyError, TypeError):
            return self.handle_unknown_message(body, message)

        try:
            self.strategies[name](message, body, message.ack_log_error)
        except KeyError as exc:
            self.handle_unknown_task(body, message, exc)
        except InvalidTaskError as exc:
            self.handle_invalid_task(body, message, exc)

    def maybe_conn_error(self, fun):
        """Applies function but ignores any connection or channel
        errors raised."""
        try:
            fun()
        except (AttributeError, ) + \
                self.connection_errors + \
                self.channel_errors:
            pass

    def close_connection(self):
        """Closes the current broker connection and all open channels."""

        # We must set self.connection to None here, so
        # that the green pidbox thread exits.
        connection, self.connection = self.connection, None

        if self.task_consumer:
            debug('Closing consumer channel...')
            self.task_consumer = \
                    self.maybe_conn_error(self.task_consumer.close)

        self.stop_pidbox_node()

        if connection:
            debug('Closing broker connection...')
            self.maybe_conn_error(connection.close)

    def stop_consumers(self, close_connection=True, join=True):
        """Stop consuming tasks and broadcast commands, also stops
        the heartbeat thread and event dispatcher.

        :keyword close_connection: Set to False to skip closing the broker
                                    connection.

        """
        if not self._state == RUN:
            return

        if self.heart:
            # Stop the heartbeat thread if it's running.
            debug('Heart: Going into cardiac arrest...')
            self.heart = self.heart.stop()

        debug('Cancelling task consumer...')
        if join and self.task_consumer:
            self.maybe_conn_error(self.task_consumer.cancel)

        if self.event_dispatcher:
            debug('Shutting down event dispatcher...')
            self.event_dispatcher = \
                    self.maybe_conn_error(self.event_dispatcher.close)

        debug('Cancelling broadcast consumer...')
        if join and self.broadcast_consumer:
            self.maybe_conn_error(self.broadcast_consumer.cancel)

        if close_connection:
            self.close_connection()

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding
        a message received.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        :param message: The message with errors.
        :param exc: The original exception instance.

        """
        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
             exc, message.content_type, message.content_encoding,
             dump_body(message, message.body))
        message.ack()

    def reset_pidbox_node(self):
        """Sets up the process mailbox."""
        self.stop_pidbox_node()
        # close previously opened channel if any.
        if self.pidbox_node.channel:
            try:
                self.pidbox_node.channel.close()
            except self.connection_errors + self.channel_errors:
                pass

        if self.pool is not None and self.pool.is_green:
            return self.pool.spawn_n(self._green_pidbox_node)
        self.pidbox_node.channel = self.connection.channel()
        self.broadcast_consumer = self.pidbox_node.listen(
                                        callback=self.on_control)

    def stop_pidbox_node(self):
        if self._pidbox_node_stopped:
            self._pidbox_node_shutdown.set()
            debug('Waiting for broadcast thread to shutdown...')
            self._pidbox_node_stopped.wait()
            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
        elif self.broadcast_consumer:
            debug('Closing broadcast channel...')
            self.broadcast_consumer = \
                self.maybe_conn_error(self.broadcast_consumer.channel.close)

    def _green_pidbox_node(self):
        """Sets up the process mailbox when running in a greenlet
        environment."""
        # THIS CODE IS TERRIBLE
        # Luckily work has already started rewriting the Consumer for 4.0.
        self._pidbox_node_shutdown = threading.Event()
        self._pidbox_node_stopped = threading.Event()
        try:
            with self._open_connection() as conn:
                info('pidbox: Connected to %s.', conn.as_uri())
                self.pidbox_node.channel = conn.default_channel
                self.broadcast_consumer = self.pidbox_node.listen(
                                            callback=self.on_control)
                with self.broadcast_consumer:
                    while not self._pidbox_node_shutdown.isSet():
                        try:
                            conn.drain_events(timeout=1.0)
                        except socket.timeout:
                            pass
        finally:
            self._pidbox_node_stopped.set()

    def reset_connection(self):
        """Re-establish the broker connection and set up consumers,
        heartbeat and the event dispatcher."""
        debug('Re-establishing connection to the broker...')
        self.stop_consumers(join=False)

        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        self.ready_queue.clear()
        self.timer.clear()

        # Re-establish the broker connection and setup the task consumer.
        self.connection = self._open_connection()
        info('consumer: Connected to %s.', self.connection.as_uri())
        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
                                    on_decode_error=self.on_decode_error)
        # QoS: Reset prefetch window.
        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
        self.qos.update()

        # Setup the process mailbox.
        self.reset_pidbox_node()

        # Flush events sent while connection was down.
        prev_event_dispatcher = self.event_dispatcher
        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
                                                hostname=self.hostname,
                                                enabled=self.send_events)
        if prev_event_dispatcher:
            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
            self.event_dispatcher.flush()

        # Restart heartbeat thread.
        self.restart_heartbeat()

        # reload all task's execution strategies.
        self.update_strategies()

        # We're back!
        self._state = RUN

    def restart_heartbeat(self):
        """Restart the heartbeat thread.

        This thread sends heartbeat events at intervals so monitors
        can tell if the worker is off-line/missing.

        """
        self.heart = Heart(self.timer, self.event_dispatcher)
        self.heart.start()

    def _open_connection(self):
        """Establish the broker connection.

        Will retry establishing the connection if the
        :setting:`BROKER_CONNECTION_RETRY` setting is enabled.

        """
        conn = self.app.connection(heartbeat=self.amqheartbeat)

        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            error(CONNECTION_ERROR, conn.as_uri(), exc,
                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))

        # remember that the connection is lazy; it won't establish
        # until it's needed.
        if not self.app.conf.BROKER_CONNECTION_RETRY:
            # retry disabled, just call connect directly.
            conn.connect()
            return conn

        return conn.ensure_connection(_error_handler,
                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                    callback=self.maybe_shutdown)

    def stop(self):
        """Stop consuming.

        Does not close the broker connection, so be sure to call
        :meth:`close_connection` when you are finished with it.

        """
        # Notifies other threads that this instance can't be used
        # anymore.
        self.close()
        debug('Stopping consumers...')
        self.stop_consumers(close_connection=False, join=True)

    def close(self):
        self._state = CLOSE

    def maybe_shutdown(self):
        if state.should_stop:
            raise SystemExit()
        elif state.should_terminate:
            raise SystemTerminate()

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
            routing_key=None, **options):
        cset = self.task_consumer
        try:
            q = self.app.amqp.queues[queue]
        except KeyError:
            exchange = queue if exchange is None else exchange
            exchange_type = 'direct' if exchange_type is None \
                                     else exchange_type
            q = self.app.amqp.queues.select_add(queue,
                    exchange=exchange,
                    exchange_type=exchange_type,
                    routing_key=routing_key, **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            logger.info('Started consuming from %r', queue)

    def cancel_task_queue(self, queue):
        self.app.amqp.queues.select_remove(queue)
        self.task_consumer.cancel_by_queue(queue)

    @property
    def info(self):
        """Returns information about this consumer instance
        as a dict.

        This is also the consumer related info returned by
        ``celeryctl stats``.

        """
        conninfo = {}
        if self.connection:
            conninfo = self.connection.info()
            conninfo.pop('password', None)  # don't send password.
        return {'broker': conninfo, 'prefetch_count': self.qos.value}