def test_qos_exceeds_16bit(self):
    """An oversized prefetch count logs a warning and falls back to 0.

    ``kombu.common.logger`` is patched so the warning call can be
    observed on the replacement object.
    """
    with patch('kombu.common.logger') as patched_logger:
        on_update = Mock()
        prefetch = QoS(on_update, 10)
        prefetch.prev = 100
        oversized = 2 ** 32
        prefetch.set(oversized)
        self.assertTrue(patched_logger.warn.called)
        on_update.assert_called_with(prefetch_count=0)
def test_qos_exceeds_16bit(self):
    """An oversized prefetch count logs a warning and falls back to 0.

    ``kombu.common.logger`` is patched so the warning call can be
    observed on the replacement object.
    """
    with patch('kombu.common.logger') as patched_logger:
        on_update = Mock()
        prefetch = QoS(on_update, 10)
        prefetch.prev = 100
        # The literal is spelled out instead of ``2 ** 32`` because of a
        # bug on OSX Py2.5:
        # https://jira.mongodb.org/browse/PYTHON-389
        oversized = 4294967296
        prefetch.set(oversized)
        self.assertTrue(patched_logger.warn.called)
        on_update.assert_called_with(prefetch_count=0)
def test_consumer_decrement_eventually(self):
    """``decrement_eventually`` lowers the value by one but not below 0."""
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, 9)

    # A value of zero must be left untouched.
    prefetch.value = 0
    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, 0)
def test_consumer_decrement_eventually(self):
    """``decrement_eventually`` lowers the value by one but not below 0."""
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.decrement_eventually()
    assert prefetch.value == 9

    # A value of zero must be left untouched.
    prefetch.value = 0
    prefetch.decrement_eventually()
    assert prefetch.value == 0
def test_loop_ignores_socket_timeout(self):
    """``loop`` returns normally when draining events times out.

    The stub connection detaches itself from the consumer and then
    raises ``socket.timeout``; the test passes if ``loop`` does not
    propagate that exception.
    """
    connection_base = current_app.connection().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None
            raise socket.timeout(10)

    consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
    consumer.connection = Connection()
    consumer.task_consumer = Mock()
    consumer.connection.obj = consumer
    consumer.qos = QoS(consumer.task_consumer, 10)

    loop_args = consumer.loop_args()
    consumer.loop(*loop_args)
def test_loop_ignores_socket_timeout(self):
    """``loop`` returns normally when draining events times out.

    The stub connection detaches itself from the consumer and then
    raises ``socket.timeout``; the test passes if ``loop`` does not
    propagate that exception.
    """
    connection_base = self.app.connection().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None
            raise socket.timeout(10)

    consumer = MyKombuConsumer(
        self.buffer.put, timer=self.timer, app=self.app,
    )
    consumer.connection = Connection()
    consumer.task_consumer = Mock()
    consumer.connection.obj = consumer
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    loop_args = consumer.loop_args()
    consumer.loop(*loop_args)
def test_with_autoscaler_file_descriptor_safety(self):
    """Re-registering the pool with the event loop must not raise.

    The pool is registered with the hub, one entry taken from the pool's
    ``_fileno_to_outq`` map is opened and closed, and the pool is then
    registered a second time. The test passes when that second
    registration completes without an error.
    """
    # Given: a test celery worker instance with auto scaling
    worker = self.create_worker(
        autoscale=[10, 5],
        use_eventloop=True,
        timer_cls='celery.utils.timer2.Timer',
        threads=False,
    )

    # Given: This test requires a QoS defined on the worker consumer
    prefetch = QoS(lambda prefetch_count: prefetch_count, 2)
    worker.consumer.qos = prefetch
    prefetch.update()

    # Given: We have started the worker pool
    worker.pool.start()

    # Then: the worker pool is the same as the autoscaler pool
    scaler = worker.autoscaler
    assert worker.pool == scaler.pool

    # Given: Utilize kombu to get the global hub state
    event_hub = get_event_loop()

    # Given: Initial call the Async Pool to register events works fine
    worker.pool.register_with_event_loop(event_hub)

    # Create some mock queue message and read from them
    reserved_requests = [Mock(name=f'req{i}') for i in range(20)]
    for request in reserved_requests:
        state.task_reserved(request)
    scaler.body()

    # Simulate a file descriptor from the list is closed by the OS
    # auto_scaler.force_scale_down(5)
    # This actually works -- it releases the semaphore properly
    # Same with calling .terminate() on the process directly
    for fd, _proc in worker.pool._pool._fileno_to_outq.items():
        # however opening this fd as a file and closing it will do it
        # NOTE(review): ``open(str(fd), "w")`` opens a *path* named after
        # the descriptor number, not the descriptor itself -- confirm
        # this is the intended way to trigger the condition.
        open(str(fd), "w").close()
        break  # Only need to do this once

    # When: Calling again to register with event loop ...
    worker.pool.register_with_event_loop(event_hub)

    # Then: test did not raise "OSError: [Errno 9] Bad file descriptor!"

    # Finally: Clean up so the threads before/after fixture passes
    worker.terminate()
    worker.pool.terminate()
def test_loop_when_socket_error(self):
    """``socket.error`` propagates while running, not once closed.

    With the blueprint in ``RUN`` state the error raised by
    ``drain_events`` must escape ``loop``; after switching to ``CLOSE``
    the same call is expected to return without raising.
    """
    connection_base = self.app.connection_for_read().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None
            raise socket.error('foo')

    consumer = self.LoopConsumer()
    consumer.blueprint.state = RUN
    broken_connection = Connection()
    consumer.connection = broken_connection
    consumer.connection.obj = consumer
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    with pytest.raises(socket.error):
        consumer.loop(*consumer.loop_args())

    # drain_events cleared the connection; restore it for the second run.
    consumer.blueprint.state = CLOSE
    consumer.connection = broken_connection
    consumer.loop(*consumer.loop_args())
def test_loop_when_socket_error(self):
    """``socket.error`` propagates while running, not once closed.

    With the namespace in ``RUN`` state the error raised by
    ``drain_events`` must escape ``loop``; after switching to ``CLOSE``
    the same call is expected to return without raising.
    """
    connection_base = current_app.connection().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None
            raise socket.error('foo')

    consumer = Consumer(self.ready_queue, timer=self.timer)
    consumer.namespace.state = RUN
    broken_connection = Connection()
    consumer.connection = broken_connection
    consumer.connection.obj = consumer
    consumer.task_consumer = Mock()
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    with self.assertRaises(socket.error):
        consumer.loop(*consumer.loop_args())

    # drain_events cleared the connection; restore it for the second run.
    consumer.namespace.state = CLOSE
    consumer.connection = broken_connection
    consumer.loop(*consumer.loop_args())
def test_loop_when_socket_error(self):
    """``socket.error`` propagates while running, not once closed.

    With the blueprint in ``RUN`` state the error raised by
    ``drain_events`` must escape ``loop``; after switching to ``CLOSE``
    the same call is expected to return without raising.
    """
    connection_base = self.app.connection().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None
            raise socket.error('foo')

    consumer = Consumer(self.buffer.put, timer=self.timer, app=self.app)
    consumer.blueprint.state = RUN
    broken_connection = Connection()
    consumer.connection = broken_connection
    consumer.connection.obj = consumer
    consumer.task_consumer = Mock()
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    with self.assertRaises(socket.error):
        consumer.loop(*consumer.loop_args())

    # drain_events cleared the connection; restore it for the second run.
    consumer.blueprint.state = CLOSE
    consumer.connection = broken_connection
    consumer.loop(*consumer.loop_args())
def reset_connection(self):
    """Reconnect to the broker and rebuild everything tied to the link.

    Stops the current consumers, clears the ready queue and timer, opens
    a new connection, then recreates the task consumer, QoS window,
    pidbox node, event dispatcher and heartbeat before marking the
    consumer as running.
    """
    debug('Re-establishing connection to the broker...')
    self.stop_consumers(join=False)

    # Drop anything queued on the old channel: those messages cannot be
    # acked any more because delivery tags are channel specific.
    self.ready_queue.clear()
    self.timer.clear()

    # New broker connection and task consumer.
    connection = self._open_connection()
    self.connection = connection
    info('consumer: Connected to %s.', connection.as_uri())
    task_consumer = self.app.amqp.TaskConsumer(
        connection, on_decode_error=self.on_decode_error,
    )
    self.task_consumer = task_consumer

    # Start over with the initial prefetch window.
    prefetch = QoS(task_consumer, self.initial_prefetch_count)
    self.qos = prefetch
    prefetch.update()

    # Process mailbox.
    self.reset_pidbox_node()

    # Replace the dispatcher, carrying over events buffered while the
    # connection was down.
    stale_dispatcher = self.event_dispatcher
    dispatcher = self.app.events.Dispatcher(
        connection, hostname=self.hostname, enabled=self.send_events,
    )
    self.event_dispatcher = dispatcher
    if stale_dispatcher:
        dispatcher.copy_buffer(stale_dispatcher)
        dispatcher.flush()

    # Heartbeat thread.
    self.restart_heartbeat()

    # Reload every task's execution strategy.
    self.update_strategies()

    # We're back!
    self._state = RUN
def start(self, c):
    """Create the task consumer and QoS tracker on consumer ``c``.

    Applies the initial prefetch count on the connection's default
    channel, builds the task consumer, and installs a ``QoS`` object
    whose callback forwards prefetch updates to that task consumer.
    """
    c.update_strategies()

    # RabbitMQ 3.3 completely redefines how basic_qos works: detect
    # whether the new qos semantics are in effect, and if so make sure
    # the 'apply_global' flag is set on qos updates.
    apply_global = not c.connection.qos_semantics_matches_spec

    # set initial prefetch count
    channel = c.connection.default_channel
    channel.basic_qos(0, c.initial_prefetch_count, apply_global)

    c.task_consumer = c.app.amqp.TaskConsumer(
        c.connection,
        on_decode_error=c.on_decode_error,
    )

    def set_prefetch_count(prefetch_count):
        # Looks up c.task_consumer at call time rather than capturing it.
        return c.task_consumer.qos(
            prefetch_count=prefetch_count,
            apply_global=apply_global,
        )

    c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
def test_loop(self):
    """``loop`` starts consuming and applies the current prefetch count.

    After two passes through ``loop`` the task consumer must have been
    asked to consume and to set ``prefetch_count=10``; a later
    ``decrement_eventually`` followed by ``update`` must push 9.
    """
    connection_base = current_app.connection().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None

    consumer = Consumer(self.buffer.put, timer=self.timer)
    consumer.connection = Connection()
    consumer.connection.obj = consumer
    consumer.task_consumer = Mock()
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    consumer.loop(*consumer.loop_args())
    consumer.loop(*consumer.loop_args())

    self.assertTrue(consumer.task_consumer.consume.call_count)
    consumer.task_consumer.qos.assert_called_with(prefetch_count=10)
    self.assertEqual(consumer.qos.value, 10)

    consumer.qos.decrement_eventually()
    self.assertEqual(consumer.qos.value, 9)

    consumer.qos.update()
    self.assertEqual(consumer.qos.value, 9)
    consumer.task_consumer.qos.assert_called_with(prefetch_count=9)
def test_loop(self):
    """``loop`` starts consuming and applies the current prefetch count.

    After two passes through ``loop`` the task consumer must have been
    asked to consume and to set ``prefetch_count=10``; a later
    ``decrement_eventually`` followed by ``update`` must push 9.
    """
    connection_base = self.app.connection_for_read().__class__

    class Connection(connection_base):
        # Back-reference to the consumer under test, set below.
        obj = None

        def drain_events(self, **kwargs):
            self.obj.connection = None

    consumer = self.LoopConsumer()
    consumer.blueprint.state = RUN
    consumer.connection = Connection()
    consumer.connection.obj = consumer
    consumer.qos = QoS(consumer.task_consumer.qos, 10)

    consumer.loop(*consumer.loop_args())
    consumer.loop(*consumer.loop_args())

    self.assertTrue(consumer.task_consumer.consume.call_count)
    consumer.task_consumer.qos.assert_called_with(prefetch_count=10)
    self.assertEqual(consumer.qos.value, 10)

    consumer.qos.decrement_eventually()
    self.assertEqual(consumer.qos.value, 9)

    consumer.qos.update()
    self.assertEqual(consumer.qos.value, 9)
    consumer.task_consumer.qos.assert_called_with(prefetch_count=9)
def test_receieve_message_eta_isoformat(self):
    """A message with an ISO-formatted ``eta`` is scheduled on the timer.

    After the message is received, an entry for the task must be present
    in the timer queue and the QoS value must be greater than it was
    before the message arrived.
    """
    consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
    eta = datetime.now().isoformat()
    message = create_message(
        Mock(), task=foo_task.name, eta=eta, args=[2, 4, 8], kwargs={},
    )

    consumer.task_consumer = Mock()
    consumer.qos = QoS(consumer.task_consumer,
                       consumer.initial_prefetch_count)
    prefetch_before = consumer.qos.value
    consumer.event_dispatcher = Mock()
    consumer.enabled = False
    consumer.update_strategies()

    consumer.receive_message(message.decode(), message)
    consumer.timer.stop()
    consumer.timer.join(1)

    # The third element of each timer queue entry is the scheduled item.
    scheduled = [entry[2] for entry in self.timer.queue]
    found = 0
    for item in scheduled:
        if item.args[0].name == foo_task.name:
            found = True

    self.assertTrue(found)
    self.assertGreater(consumer.qos.value, prefetch_before)
    consumer.timer.stop()
def test_consumer_increment_decrement(self):
    """Value changes reach the callback only when ``update`` is called.

    Also checks that a value of 0 is left alone by both
    ``decrement_eventually`` and ``increment_eventually``.
    """
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.update()
    self.assertEqual(prefetch.value, 10)
    consumer_mock.qos.assert_called_with(prefetch_count=10)

    prefetch.decrement_eventually()
    prefetch.update()
    self.assertEqual(prefetch.value, 9)
    consumer_mock.qos.assert_called_with(prefetch_count=9)

    # Without an update() the callback still holds the previous count.
    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, 8)
    consumer_mock.qos.assert_called_with(prefetch_count=9)
    self.assertIn({'prefetch_count': 9}, consumer_mock.qos.call_args)

    # Does not decrement 0 value
    prefetch.value = 0
    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, 0)
    prefetch.increment_eventually()
    self.assertEqual(prefetch.value, 0)
def test_exceeds_short(self):
    """The tracked value may move past ``PREFETCH_COUNT_MAX`` and back."""
    limit = PREFETCH_COUNT_MAX
    prefetch = QoS(Mock(), limit - 1)

    prefetch.update()
    self.assertEqual(prefetch.value, limit - 1)

    # Step up across the limit ...
    prefetch.increment_eventually()
    self.assertEqual(prefetch.value, limit)
    prefetch.increment_eventually()
    self.assertEqual(prefetch.value, limit + 1)

    # ... and back down again.
    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, limit)
    prefetch.decrement_eventually()
    self.assertEqual(prefetch.value, limit - 1)
def __init__(self, value):
    """Store ``value`` and initialise the ``QoS`` base with no callback.

    :param value: Stored on the instance and passed to
        ``QoS.__init__`` as its second argument.
    """
    # The attribute is assigned before delegating to the base class;
    # NOTE(review): whether QoS.__init__ relies on this is not visible
    # here -- keep the order.
    self.value = value
    QoS.__init__(self, None, value)
def test_set(self):
    """``set`` records the new value in ``prev``."""
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.set(12)
    assert prefetch.prev == 12

    # Setting the same value again must not raise.
    prefetch.set(prefetch.prev)
def test_consumer_increment_decrement(self):
    """Value changes reach the callback only when ``update`` is called.

    Also checks that a value of 0 is left alone by both
    ``decrement_eventually`` and ``increment_eventually``.
    """
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.update()
    assert prefetch.value == 10
    consumer_mock.qos.assert_called_with(prefetch_count=10)

    prefetch.decrement_eventually()
    prefetch.update()
    assert prefetch.value == 9
    consumer_mock.qos.assert_called_with(prefetch_count=9)

    # Without an update() the callback still holds the previous count.
    prefetch.decrement_eventually()
    assert prefetch.value == 8
    consumer_mock.qos.assert_called_with(prefetch_count=9)
    assert {'prefetch_count': 9} in consumer_mock.qos.call_args

    # Does not decrement 0 value
    prefetch.value = 0
    prefetch.decrement_eventually()
    assert prefetch.value == 0
    prefetch.increment_eventually()
    assert prefetch.value == 0
def test_set(self):
    """``set`` records the new value in ``prev``."""
    consumer_mock = Mock()
    prefetch = QoS(consumer_mock.qos, 10)

    prefetch.set(12)
    self.assertEqual(prefetch.prev, 12)

    # Setting the same value again must not raise.
    prefetch.set(prefetch.prev)
class Consumer(object):
    """Listen for messages received from the broker and
    move them to the ready queue for task processing.

    :param ready_queue: See :attr:`ready_queue`.
    :param timer: See :attr:`timer`.

    """

    #: The queue that holds tasks ready for immediate processing.
    ready_queue = None

    #: Enable/disable events.
    send_events = False

    #: Optional callback to be called when the connection is established.
    #: Will only be called once, even if the connection is lost and
    #: re-established.
    init_callback = None

    #: The current hostname.  Defaults to the system hostname.
    hostname = None

    #: Initial QoS prefetch count for the task channel.
    initial_prefetch_count = 0

    #: A :class:`celery.events.EventDispatcher` for sending events.
    event_dispatcher = None

    #: The thread that sends event heartbeats at regular intervals.
    #: The heartbeats are used by monitors to detect that a worker
    #: went offline/disappeared.
    heart = None

    #: The broker connection.
    connection = None

    #: The consumer used to consume task messages.
    task_consumer = None

    #: The consumer used to consume broadcast commands.
    broadcast_consumer = None

    #: The process mailbox (kombu pidbox node).
    pidbox_node = None
    _pidbox_node_shutdown = None   # used for greenlets
    _pidbox_node_stopped = None    # used for greenlets

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    # Consumer state, can be RUN or CLOSE.
    _state = None

    def __init__(self, ready_queue,
            init_callback=noop, send_events=False, hostname=None,
            initial_prefetch_count=2, pool=None, app=None,
            timer=None, controller=None, hub=None, amqheartbeat=None,
            **kwargs):
        """Store the configuration and create the pidbox node.

        ``self.connection`` and ``self.task_consumer`` are left as
        ``None`` here; they are assigned in :meth:`reset_connection`.

        :param ready_queue: See :attr:`ready_queue`.
        :keyword init_callback: See :attr:`init_callback`.
        :keyword send_events: See :attr:`send_events`.
        :keyword hostname: Falls back to ``socket.gethostname()``.
        :keyword initial_prefetch_count: See
            :attr:`initial_prefetch_count`.
        :keyword pool: See :attr:`pool`.
        :keyword app: Resolved through ``app_or_default``.
        :keyword timer: Falls back to ``timer2.default_timer``.
        :keyword controller: Stored as ``self.controller``.
        :keyword hub: Optional event loop hub; when given,
            :meth:`on_poll_init` is appended to its ``on_init`` list.
        :keyword amqheartbeat: Heartbeat value; falls back to the
            ``BROKER_HEARTBEAT`` setting and is forced to 0 without a hub.

        """
        self.app = app_or_default(app)
        self.connection = None
        self.task_consumer = None
        self.controller = controller
        self.broadcast_consumer = None
        self.ready_queue = ready_queue
        self.send_events = send_events
        self.init_callback = init_callback
        self.hostname = hostname or socket.gethostname()
        self.initial_prefetch_count = initial_prefetch_count
        self.event_dispatcher = None
        self.heart = None
        self.pool = pool
        self.timer = timer or timer2.default_timer
        pidbox_state = AttributeDict(app=self.app,
                                     hostname=self.hostname,
                                     listener=self,     # pre 2.2
                                     consumer=self)
        self.pidbox_node = self.app.control.mailbox.Node(self.hostname,
                                                         state=pidbox_state,
                                                         handlers=Panel.data)
        # This connection object is only used to read the error classes.
        conninfo = self.app.connection()
        self.connection_errors = conninfo.connection_errors
        self.channel_errors = conninfo.channel_errors

        # Cached so on_task() does not query the logger for every task.
        self._does_info = logger.isEnabledFor(logging.INFO)
        self.strategies = {}
        if hub:
            hub.on_init.append(self.on_poll_init)
        self.hub = hub
        self._quick_put = self.ready_queue.put
        self.amqheartbeat = amqheartbeat
        if self.amqheartbeat is None:
            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
        # Without a hub the heartbeat value is always reset to 0.
        if not hub:
            self.amqheartbeat = 0

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.BROKER_CONNECTION_TIMEOUT = None

    def update_strategies(self):
        """Rebuild the strategy table and tracer for every app task.

        For each registered task this stores the result of
        ``task.start_strategy`` in :attr:`strategies` and assigns a new
        tracer to ``task.__trace__``.

        """
        S = self.strategies
        app = self.app
        loader = app.loader
        hostname = self.hostname
        for name, task in self.app.tasks.iteritems():
            S[name] = task.start_strategy(app, self)
            task.__trace__ = build_tracer(name, task, loader, hostname)

    def start(self):
        """Start the consumer.

        Automatically survives intermittent connection failure,
        and will retry establishing the connection and restart
        consuming messages.

        """

        self.init_callback(self)

        while self._state != CLOSE:
            self.maybe_shutdown()
            try:
                self.reset_connection()
                self.consume_messages()
            except self.connection_errors + self.channel_errors:
                # Logged only; the loop then reconnects on the next pass.
                error(RETRY_CONNECTION, exc_info=True)

    def on_poll_init(self, hub):
        """Hub ``on_init`` callback (registered in :meth:`__init__`).

        Adds the connection's event map to the hub's readers and passes
        the hub's poller to the transport's own ``on_poll_init``.

        """
        hub.update_readers(self.connection.eventmap)
        self.connection.transport.on_poll_init(hub.poller)

    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
            hbrate=AMQHEARTBEAT_RATE):
        """Consume messages forever (or until an exception is raised)."""

        with self.hub as hub:
            # Everything the loop needs is bound to local names up front
            # (presumably to keep attribute lookups out of the loop).
            qos = self.qos
            update_qos = qos.update
            update_readers = hub.update_readers
            readers, writers = hub.readers, hub.writers
            poll = hub.poller.poll
            fire_timers = hub.fire_timers
            scheduled = hub.timer._queue
            connection = self.connection
            hb = self.amqheartbeat
            hbtick = connection.heartbeat_check
            on_poll_start = connection.transport.on_poll_start
            on_poll_empty = connection.transport.on_poll_empty
            strategies = self.strategies
            drain_nowait = connection.drain_nowait
            on_task_callbacks = hub.on_task
            keep_draining = connection.transport.nb_keep_draining

            # Schedule the heartbeat check on the hub timer when enabled.
            if hb and connection.supports_heartbeats:
                hub.timer.apply_interval(
                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))

            def on_task_received(body, message):
                # Same dispatch as receive_message(), but first runs the
                # hub's on_task callbacks and uses the local strategies.
                if on_task_callbacks:
                    [callback() for callback in on_task_callbacks]
                try:
                    name = body['task']
                except (KeyError, TypeError):
                    return self.handle_unknown_message(body, message)
                try:
                    strategies[name](message, body, message.ack_log_error)
                except KeyError as exc:
                    self.handle_unknown_task(body, message, exc)
                except InvalidTaskError as exc:
                    self.handle_invalid_task(body, message, exc)
                #fire_timers()

            self.task_consumer.callbacks = [on_task_received]
            self.task_consumer.consume()

            debug('Ready to accept tasks!')

            while self._state != CLOSE and self.connection:
                # shutdown if signal handlers told us to.
                if state.should_stop:
                    raise SystemExit()
                elif state.should_terminate:
                    raise SystemTerminate()

                # fire any ready timers, this also returns
                # the number of seconds until we need to fire timers again.
                poll_timeout = fire_timers() if scheduled else 1

                # We only update QoS when there is no more messages to read.
                # This groups together qos calls, and makes sure that remote
                # control commands will be prioritized over task messages.
                if qos.prev != qos.value:
                    update_qos()

                update_readers(on_poll_start())
                if readers or writers:
                    connection.more_to_read = True
                    while connection.more_to_read:
                        try:
                            events = poll(poll_timeout)
                        except ValueError:  # Issue 882
                            return
                        if not events:
                            on_poll_empty()
                        for fileno, event in events or ():
                            try:
                                if event & READ:
                                    readers[fileno](fileno, event)
                                if event & WRITE:
                                    writers[fileno](fileno, event)
                                if event & ERR:
                                    # Errors go to whichever handlers
                                    # are registered for this fileno.
                                    for handlermap in readers, writers:
                                        try:
                                            handlermap[fileno](fileno, event)
                                        except KeyError:
                                            pass
                            except (KeyError, Empty):
                                continue
                            except socket.error:
                                # Only swallowed once the consumer is
                                # closing.
                                if self._state != CLOSE:  # pragma: no cover
                                    raise
                        if keep_draining:
                            drain_nowait()
                            poll_timeout = 0
                        else:
                            connection.more_to_read = False
                else:
                    # no sockets yet, startup is probably not done.
                    sleep(min(poll_timeout, 0.1))

    def on_task(self, task, task_reserved=task_reserved):
        """Handle received task.

        If the task has an `eta` we enter it into the ETA schedule,
        otherwise we move it to the ready queue for immediate processing.

        """
        if task.revoked():
            return

        if self._does_info:
            info('Got task from broker: %s', task)

        if self.event_dispatcher.enabled:
            self.event_dispatcher.send('task-received', uuid=task.id,
                    name=task.name, args=safe_repr(task.args),
                    kwargs=safe_repr(task.kwargs),
                    retries=task.request_dict.get('retries', 0),
                    eta=task.eta and task.eta.isoformat(),
                    expires=task.expires and task.expires.isoformat())

        if task.eta:
            eta = timezone.to_system(task.eta) if task.utc else task.eta
            try:
                eta = timer2.to_timestamp(eta)
            except OverflowError as exc:
                # The eta cannot be scheduled: log it and acknowledge the
                # task without scheduling it.
                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                      task.eta, exc, task.info(safe=True), exc_info=True)
                task.acknowledge()
            else:
                # The matching decrement happens in apply_eta_task().
                self.qos.increment_eventually()
                self.timer.apply_at(
                    eta, self.apply_eta_task, (task, ), priority=6,
                )
        else:
            task_reserved(task)
            self._quick_put(task)

    def on_control(self, body, message):
        """Process remote control command message."""
        try:
            self.pidbox_node.handle_message(body, message)
        except KeyError as exc:
            error('No such control command: %s', exc)
        except Exception as exc:
            # Any other failure is logged and the pidbox node is reset.
            error('Control command error: %r', exc, exc_info=True)
            self.reset_pidbox_node()

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an
        ETA/countdown."""
        task_reserved(task)
        self._quick_put(task)
        # Undoes the increment made in on_task() when it was scheduled.
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        """Return ``MESSAGE_REPORT`` formatted with the message's dumped
        body, content type, content encoding and delivery info."""
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info))

    def handle_unknown_message(self, body, message):
        """Log a warning for a message without a usable ``task`` key and
        reject it."""
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)

    def handle_unknown_task(self, body, message, exc):
        """Log an error for a task name with no strategy and reject the
        message."""
        error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def handle_invalid_task(self, body, message, exc):
        """Log an error for a task that raised ``InvalidTaskError`` and
        reject the message."""
        error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def receive_message(self, body, message):
        """Handles incoming messages.

        :param body: The message body.
        :param message: The kombu message object.

        """
        try:
            name = body['task']
        except (KeyError, TypeError):
            return self.handle_unknown_message(body, message)

        try:
            self.strategies[name](message, body, message.ack_log_error)
        except KeyError as exc:
            self.handle_unknown_task(body, message, exc)
        except InvalidTaskError as exc:
            self.handle_invalid_task(body, message, exc)

    def maybe_conn_error(self, fun):
        """Applies function but ignores any connection or channel
        errors raised (``AttributeError`` is ignored as well).

        Always returns ``None``; callers use that to clear the attribute
        holding the object they just closed.

        """
        try:
            fun()
        except (AttributeError, ) + \
                self.connection_errors + \
                self.channel_errors:
            pass

    def close_connection(self):
        """Closes the current broker connection and all open channels."""

        # We must set self.connection to None here, so
        # that the green pidbox thread exits.
        connection, self.connection = self.connection, None

        if self.task_consumer:
            debug('Closing consumer channel...')
            self.task_consumer = \
                    self.maybe_conn_error(self.task_consumer.close)

        self.stop_pidbox_node()

        if connection:
            debug('Closing broker connection...')
            self.maybe_conn_error(connection.close)

    def stop_consumers(self, close_connection=True, join=True):
        """Stop consuming tasks and broadcast commands, also stops
        the heartbeat thread and event dispatcher.

        Does nothing unless the consumer state is ``RUN``.

        :keyword close_connection: Set to False to skip closing the broker
                                    connection.
        :keyword join: When false, the task and broadcast consumers are
                       not cancelled.

        """
        if not self._state == RUN:
            return

        if self.heart:
            # Stop the heartbeat thread if it's running.
            debug('Heart: Going into cardiac arrest...')
            self.heart = self.heart.stop()

        debug('Cancelling task consumer...')
        if join and self.task_consumer:
            self.maybe_conn_error(self.task_consumer.cancel)

        if self.event_dispatcher:
            debug('Shutting down event dispatcher...')
            self.event_dispatcher = \
                    self.maybe_conn_error(self.event_dispatcher.close)

        debug('Cancelling broadcast consumer...')
        if join and self.broadcast_consumer:
            self.maybe_conn_error(self.broadcast_consumer.cancel)

        if close_connection:
            self.close_connection()

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding
        a message received.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        :param message: The message with errors.
        :param exc: The original exception instance.

        """
        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r')",
             exc, message.content_type, message.content_encoding,
             dump_body(message, message.body))
        message.ack()

    def reset_pidbox_node(self):
        """Sets up the process mailbox."""
        self.stop_pidbox_node()
        # close previously opened channel if any.
        if self.pidbox_node.channel:
            try:
                self.pidbox_node.channel.close()
            except self.connection_errors + self.channel_errors:
                pass

        # With a green pool the mailbox is set up by
        # _green_pidbox_node(), spawned through the pool.
        if self.pool is not None and self.pool.is_green:
            return self.pool.spawn_n(self._green_pidbox_node)
        self.pidbox_node.channel = self.connection.channel()
        self.broadcast_consumer = self.pidbox_node.listen(
                                        callback=self.on_control)

    def stop_pidbox_node(self):
        """Stop the process mailbox.

        If the green pidbox events exist, signal shutdown and wait for
        :meth:`_green_pidbox_node` to finish; otherwise close the
        broadcast consumer's channel, if there is one.

        """
        if self._pidbox_node_stopped:
            self._pidbox_node_shutdown.set()
            debug('Waiting for broadcast thread to shutdown...')
            self._pidbox_node_stopped.wait()
            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
        elif self.broadcast_consumer:
            debug('Closing broadcast channel...')
            self.broadcast_consumer = \
                self.maybe_conn_error(self.broadcast_consumer.channel.close)

    def _green_pidbox_node(self):
        """Sets up the process mailbox when running in a greenlet
        environment.

        Opens its own connection and drains events from it until the
        shutdown event is set; the stopped event is always set on exit.

        """
        # THIS CODE IS TERRIBLE
        # Luckily work has already started rewriting the Consumer for 4.0.
        self._pidbox_node_shutdown = threading.Event()
        self._pidbox_node_stopped = threading.Event()
        try:
            with self._open_connection() as conn:
                info('pidbox: Connected to %s.', conn.as_uri())
                self.pidbox_node.channel = conn.default_channel
                self.broadcast_consumer = self.pidbox_node.listen(
                                            callback=self.on_control)
                with self.broadcast_consumer:
                    while not self._pidbox_node_shutdown.isSet():
                        try:
                            # The 1 second timeout lets the loop re-check
                            # the shutdown event.
                            conn.drain_events(timeout=1.0)
                        except socket.timeout:
                            pass
        finally:
            self._pidbox_node_stopped.set()

    def reset_connection(self):
        """Re-establish the broker connection and set up consumers,
        heartbeat and the event dispatcher."""
        debug('Re-establishing connection to the broker...')
        self.stop_consumers(join=False)

        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        self.ready_queue.clear()
        self.timer.clear()

        # Re-establish the broker connection and setup the task consumer.
        self.connection = self._open_connection()
        info('consumer: Connected to %s.', self.connection.as_uri())
        self.task_consumer = self.app.amqp.TaskConsumer(self.connection,
                                    on_decode_error=self.on_decode_error)
        # QoS: Reset prefetch window.
        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
        self.qos.update()

        # Setup the process mailbox.
        self.reset_pidbox_node()

        # Flush events sent while connection was down.
        prev_event_dispatcher = self.event_dispatcher
        self.event_dispatcher = self.app.events.Dispatcher(self.connection,
                                                hostname=self.hostname,
                                                enabled=self.send_events)
        if prev_event_dispatcher:
            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
            self.event_dispatcher.flush()

        # Restart heartbeat thread.
        self.restart_heartbeat()

        # reload all task's execution strategies.
        self.update_strategies()

        # We're back!
        self._state = RUN

    def restart_heartbeat(self):
        """Restart the heartbeat thread.

        This thread sends heartbeat events at intervals so monitors
        can tell if the worker is off-line/missing.

        """
        self.heart = Heart(self.timer, self.event_dispatcher)
        self.heart.start()

    def _open_connection(self):
        """Establish the broker connection.

        Will retry establishing the connection if the
        :setting:`BROKER_CONNECTION_RETRY` setting is enabled

        """
        conn = self.app.connection(heartbeat=self.amqheartbeat)

        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
            # Report a failover instead of a retry when the connection
            # has alternates and the interval is 0.
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            error(CONNECTION_ERROR, conn.as_uri(), exc,
                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))

        # remember that the connection is lazy, it won't establish
        # until it's needed.
        if not self.app.conf.BROKER_CONNECTION_RETRY:
            # retry disabled, just call connect directly.
            conn.connect()
            return conn

        return conn.ensure_connection(_error_handler,
                    self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
                    callback=self.maybe_shutdown)

    def stop(self):
        """Stop consuming.

        Does not close the broker connection, so be sure to call
        :meth:`close_connection` when you are finished with it.

        """
        # Notifies other threads that this instance can't be used
        # anymore.
        self.close()
        debug('Stopping consumers...')
        self.stop_consumers(close_connection=False, join=True)

    def close(self):
        """Set the consumer state to ``CLOSE``."""
        self._state = CLOSE

    def maybe_shutdown(self):
        """Raise ``SystemExit`` or ``SystemTerminate`` if the worker
        state has ``should_stop`` or ``should_terminate`` set."""
        if state.should_stop:
            raise SystemExit()
        elif state.should_terminate:
            raise SystemTerminate()

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
            routing_key=None, **options):
        """Start consuming from ``queue`` unless already doing so.

        If the queue is not among the app's queues it is added with
        ``select_add``; ``exchange`` then defaults to the queue name and
        ``exchange_type`` to ``'direct'``.

        """
        cset = self.task_consumer
        try:
            q = self.app.amqp.queues[queue]
        except KeyError:
            exchange = queue if exchange is None else exchange
            exchange_type = 'direct' if exchange_type is None \
                                     else exchange_type
            q = self.app.amqp.queues.select_add(queue,
                    exchange=exchange,
                    exchange_type=exchange_type,
                    routing_key=routing_key, **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            logger.info('Started consuming from %r', queue)

    def cancel_task_queue(self, queue):
        """Remove ``queue`` from the app's selected queues and cancel
        the task consumer's subscription to it."""
        self.app.amqp.queues.select_remove(queue)
        self.task_consumer.cancel_by_queue(queue)

    @property
    def info(self):
        """Returns information about this consumer instance
        as a dict.

        This is also the consumer related info returned by
        ``celeryctl stats``.

        """
        conninfo = {}
        if self.connection:
            conninfo = self.connection.info()
            conninfo.pop('password', None)  # don't send password.
        return {'broker': conninfo, 'prefetch_count': self.qos.value}
def start(self, c):
    """Create the task consumer and QoS tracker on consumer ``c``.

    The QoS object is bound to the new task consumer's ``qos`` method
    and is updated once so the initial prefetch count is applied.
    """
    task_consumer = c.app.amqp.TaskConsumer(
        c.connection,
        on_decode_error=c.on_decode_error,
    )
    c.task_consumer = task_consumer
    c.qos = QoS(c.task_consumer.qos, self.initial_prefetch_count)
    # set initial prefetch count
    c.qos.update()
def test_exceeds_short(self):
    """The tracked value may move past ``PREFETCH_COUNT_MAX`` and back."""
    limit = PREFETCH_COUNT_MAX
    prefetch = QoS(Mock(), limit - 1)

    prefetch.update()
    assert prefetch.value == limit - 1

    # Step up across the limit ...
    prefetch.increment_eventually()
    assert prefetch.value == limit
    prefetch.increment_eventually()
    assert prefetch.value == limit + 1

    # ... and back down again.
    prefetch.decrement_eventually()
    assert prefetch.value == limit
    prefetch.decrement_eventually()
    assert prefetch.value == limit - 1
class Consumer(object):
    """Listen for messages received from the broker and
    move them to the ready queue for task processing.

    :param ready_queue: See :attr:`ready_queue`.
    :param timer: See :attr:`timer`.

    """

    #: The queue that holds tasks ready for immediate processing.
    ready_queue = None

    #: Enable/disable events.
    send_events = False

    #: Optional callback to be called when the connection is established.
    #: Will only be called once, even if the connection is lost and
    #: re-established.
    init_callback = None

    #: The current hostname.  Defaults to the system hostname.
    hostname = None

    #: Initial QoS prefetch count for the task channel.
    initial_prefetch_count = 0

    #: A :class:`celery.events.EventDispatcher` for sending events.
    event_dispatcher = None

    #: The thread that sends event heartbeats at regular intervals.
    #: The heartbeats are used by monitors to detect that a worker
    #: went offline/disappeared.
    heart = None

    #: The broker connection.
    connection = None

    #: The consumer used to consume task messages.
    task_consumer = None

    #: The consumer used to consume broadcast commands.
    broadcast_consumer = None

    #: The process mailbox (kombu pidbox node).
    pidbox_node = None
    _pidbox_node_shutdown = None   # used for greenlets
    _pidbox_node_stopped = None    # used for greenlets

    #: The current worker pool instance.
    pool = None

    #: A timer used for high-priority internal tasks, such
    #: as sending heartbeats.
    timer = None

    # Consumer state, can be RUN or CLOSE.
    _state = None

    def __init__(self, ready_queue,
                 init_callback=noop, send_events=False, hostname=None,
                 initial_prefetch_count=2, pool=None, app=None,
                 timer=None, controller=None, hub=None, amqheartbeat=None,
                 **kwargs):
        """Store the configuration and create the pidbox node.

        No consuming connection is opened here; that happens in
        :meth:`reset_connection`.

        :param ready_queue: See :attr:`ready_queue`.
        :keyword init_callback: See :attr:`init_callback`.
        :keyword send_events: See :attr:`send_events`.
        :keyword hostname: Defaults to ``socket.gethostname()``.
        :keyword initial_prefetch_count: See
            :attr:`initial_prefetch_count`.
        :keyword pool: See :attr:`pool`.
        :keyword app: Passed through ``app_or_default``.
        :keyword timer: Defaults to ``timer2.default_timer``.
        :keyword controller: Stored as ``self.controller``.
        :keyword hub: Event loop hub; when given, :meth:`on_poll_init`
            is appended to ``hub.on_init``.
        :keyword amqheartbeat: Broker heartbeat value; falls back to
            the ``BROKER_HEARTBEAT`` setting when ``None`` and is forced
            to 0 when there is no hub.

        """
        self.app = app_or_default(app)
        self.connection = None
        self.task_consumer = None
        self.controller = controller
        self.broadcast_consumer = None
        self.ready_queue = ready_queue
        self.send_events = send_events
        self.init_callback = init_callback
        self.hostname = hostname or socket.gethostname()
        self.initial_prefetch_count = initial_prefetch_count
        self.event_dispatcher = None
        self.heart = None
        self.pool = pool
        self.timer = timer or timer2.default_timer
        pidbox_state = AttributeDict(app=self.app,
                                     hostname=self.hostname,
                                     listener=self,     # pre 2.2
                                     consumer=self)
        self.pidbox_node = self.app.control.mailbox.Node(
            self.hostname, state=pidbox_state, handlers=Panel.data)

        # This connection object is only used to look up the error
        # classes of the configured transport.
        conninfo = self.app.connection()
        self.connection_errors = conninfo.connection_errors
        self.channel_errors = conninfo.channel_errors

        # Cached so on_task() does not query the logger per message.
        self._does_info = logger.isEnabledFor(logging.INFO)
        self.strategies = {}
        if hub:
            hub.on_init.append(self.on_poll_init)
        self.hub = hub
        self._quick_put = self.ready_queue.put
        self.amqheartbeat = amqheartbeat
        if self.amqheartbeat is None:
            self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
        if not hub:
            self.amqheartbeat = 0

        if _detect_environment() == 'gevent':
            # there's a gevent bug that causes timeouts to not be reset,
            # so if the connection timeout is exceeded once, it can NEVER
            # connect again.
            self.app.conf.BROKER_CONNECTION_TIMEOUT = None

    def update_strategies(self):
        """Rebuild the execution strategy and tracer of every task.

        Fills :attr:`strategies` (task name -> strategy callable) and
        sets ``task.__trace__`` for each task registered in the app.

        """
        S = self.strategies
        app = self.app
        loader = app.loader
        hostname = self.hostname

        for name, task in self.app.tasks.iteritems():
            S[name] = task.start_strategy(app, self)
            task.__trace__ = build_tracer(name, task, loader, hostname)

    def start(self):
        """Start the consumer.

        Automatically survives intermittent connection failure,
        and will retry establishing the connection and restart
        consuming messages.

        """

        self.init_callback(self)

        while self._state != CLOSE:
            self.maybe_shutdown()
            try:
                self.reset_connection()
                self.consume_messages()
            except self.connection_errors + self.channel_errors:
                error(RETRY_CONNECTION, exc_info=True)

    def on_poll_init(self, hub):
        """Register the connection with the hub and its poller.

        :param hub: The event loop hub.

        """
        hub.update_readers(self.connection.eventmap)
        self.connection.transport.on_poll_init(hub.poller)

    def consume_messages(self, sleep=sleep, min=min, Empty=Empty,
                         hbrate=AMQHEARTBEAT_RATE):
        """Consume messages forever (or until an exception is raised).

        Runs the hub poll loop while the state is not ``CLOSE`` and a
        connection is set.  Returns early if ``poll`` raises
        :exc:`ValueError` (Issue 882).

        :keyword sleep: Bound as a default; used while no readers or
            writers are registered yet.
        :keyword min: Bound as a default.
        :keyword Empty: Bound as a default; exception ignored while
            dispatching events.
        :keyword hbrate: Divisor of the heartbeat value for the check
            interval, also passed on to ``connection.heartbeat_check``.

        :raises SystemExit: if ``state.should_stop`` is set.
        :raises SystemTerminate: if ``state.should_terminate`` is set.

        """

        with self.hub as hub:
            # Everything used inside the loop is bound to a local name
            # up front.
            qos = self.qos
            update_qos = qos.update
            update_readers = hub.update_readers
            readers, writers = hub.readers, hub.writers
            poll = hub.poller.poll
            fire_timers = hub.fire_timers
            scheduled = hub.timer._queue
            connection = self.connection
            hb = self.amqheartbeat
            hbtick = connection.heartbeat_check
            on_poll_start = connection.transport.on_poll_start
            on_poll_empty = connection.transport.on_poll_empty
            strategies = self.strategies
            drain_nowait = connection.drain_nowait
            on_task_callbacks = hub.on_task
            keep_draining = connection.transport.nb_keep_draining

            if hb and connection.supports_heartbeats:
                hub.timer.apply_interval(
                    hb * 1000.0 / hbrate, hbtick, (hbrate, ))

            def on_task_received(body, message):
                # Same dispatch as receive_message(), but using the
                # locally bound ``strategies`` and running the hub's
                # on_task callbacks first.
                if on_task_callbacks:
                    for callback in on_task_callbacks:
                        callback()
                try:
                    name = body['task']
                except (KeyError, TypeError):
                    return self.handle_unknown_message(body, message)
                try:
                    strategies[name](message, body, message.ack_log_error)
                except KeyError as exc:
                    self.handle_unknown_task(body, message, exc)
                except InvalidTaskError as exc:
                    self.handle_invalid_task(body, message, exc)

            self.task_consumer.callbacks = [on_task_received]
            self.task_consumer.consume()

            debug('Ready to accept tasks!')

            while self._state != CLOSE and self.connection:
                # shutdown if signal handlers told us to.
                if state.should_stop:
                    raise SystemExit()
                elif state.should_terminate:
                    raise SystemTerminate()

                # fire any ready timers, this also returns
                # the number of seconds until we need to fire timers again.
                poll_timeout = fire_timers() if scheduled else 1

                # We only update QoS when there is no more messages to read.
                # This groups together qos calls, and makes sure that remote
                # control commands will be prioritized over task messages.
                if qos.prev != qos.value:
                    update_qos()

                update_readers(on_poll_start())
                if readers or writers:
                    connection.more_to_read = True
                    while connection.more_to_read:
                        try:
                            events = poll(poll_timeout)
                        except ValueError:  # Issue 882
                            return
                        if not events:
                            on_poll_empty()
                        for fileno, event in events or ():
                            try:
                                if event & READ:
                                    readers[fileno](fileno, event)
                                if event & WRITE:
                                    writers[fileno](fileno, event)
                                if event & ERR:
                                    # Error events go to whichever
                                    # handlers exist for this fileno.
                                    for handlermap in readers, writers:
                                        try:
                                            handlermap[fileno](fileno, event)
                                        except KeyError:
                                            pass
                            except (KeyError, Empty):
                                continue
                            except socket.error:
                                if self._state != CLOSE:  # pragma: no cover
                                    raise
                        if keep_draining:
                            drain_nowait()
                            # Do not block on the next poll while draining.
                            poll_timeout = 0
                        else:
                            connection.more_to_read = False
                else:
                    # no sockets yet, startup is probably not done.
                    sleep(min(poll_timeout, 0.1))

    def on_task(self, task, task_reserved=task_reserved):
        """Handle received task.

        If the task has an `eta` we enter it into the ETA schedule,
        otherwise we move it the ready queue for immediate processing.

        Revoked tasks are dropped.  A task whose eta cannot be converted
        to a timestamp is logged and acknowledged.

        :param task: The task request.
        :keyword task_reserved: Bound as a default; called with the task
            before it is put on the ready queue.

        """
        if task.revoked():
            return

        if self._does_info:
            info('Got task from broker: %s', task)

        if self.event_dispatcher.enabled:
            self.event_dispatcher.send(
                'task-received', uuid=task.id,
                name=task.name, args=safe_repr(task.args),
                kwargs=safe_repr(task.kwargs),
                retries=task.request_dict.get('retries', 0),
                eta=task.eta and task.eta.isoformat(),
                expires=task.expires and task.expires.isoformat())

        if task.eta:
            try:
                eta = timer2.to_timestamp(task.eta)
            except OverflowError as exc:
                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
                      task.eta, exc, task.info(safe=True), exc_info=True)
                task.acknowledge()
            else:
                # The matching decrement happens in apply_eta_task().
                self.qos.increment_eventually()
                self.timer.apply_at(eta, self.apply_eta_task, (task, ),
                                    priority=6)
        else:
            task_reserved(task)
            self._quick_put(task)

    def on_control(self, body, message):
        """Process remote control command message.

        Unknown commands are logged.  Any other error is logged and the
        pidbox node is reset.

        """
        try:
            self.pidbox_node.handle_message(body, message)
        except KeyError as exc:
            error('No such control command: %s', exc)
        except Exception as exc:
            error('Control command error: %r', exc, exc_info=True)
            self.reset_pidbox_node()

    def apply_eta_task(self, task):
        """Method called by the timer to apply a task with an
        ETA/countdown."""
        task_reserved(task)
        self._quick_put(task)
        # Undo the increment made in on_task() when it was scheduled.
        self.qos.decrement_eventually()

    def _message_report(self, body, message):
        """Return ``MESSAGE_REPORT`` formatted with the message details."""
        return MESSAGE_REPORT.format(dump_body(message, body),
                                     safe_repr(message.content_type),
                                     safe_repr(message.content_encoding),
                                     safe_repr(message.delivery_info))

    def handle_unknown_message(self, body, message):
        """Log a warning for a message without a task name, and reject it."""
        warn(UNKNOWN_FORMAT, self._message_report(body, message))
        message.reject_log_error(logger, self.connection_errors)

    def handle_unknown_task(self, body, message, exc):
        """Log an error for a task with no strategy, and reject it."""
        error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body),
              exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def handle_invalid_task(self, body, message, exc):
        """Log an error for an invalid task message, and reject it."""
        error(INVALID_TASK_ERROR, exc, dump_body(message, body),
              exc_info=True)
        message.reject_log_error(logger, self.connection_errors)

    def receive_message(self, body, message):
        """Handles incoming messages.

        :param body: The message body.
        :param message: The kombu message object.

        """
        try:
            name = body['task']
        except (KeyError, TypeError):
            return self.handle_unknown_message(body, message)

        try:
            self.strategies[name](message, body, message.ack_log_error)
        except KeyError as exc:
            self.handle_unknown_task(body, message, exc)
        except InvalidTaskError as exc:
            self.handle_invalid_task(body, message, exc)

    def maybe_conn_error(self, fun):
        """Applies function but ignores any connection or channel
        errors raised.

        :exc:`AttributeError` is ignored as well.  Nothing is returned,
        so callers also use the result to reset an attribute to ``None``.

        """
        try:
            fun()
        except (AttributeError, ) + \
                self.connection_errors + \
                self.channel_errors:
            pass

    def close_connection(self):
        """Closes the current broker connection and all open channels."""

        # We must set self.connection to None here, so
        # that the green pidbox thread exits.
        connection, self.connection = self.connection, None

        if self.task_consumer:
            debug('Closing consumer channel...')
            self.task_consumer = \
                self.maybe_conn_error(self.task_consumer.close)

        self.stop_pidbox_node()

        if connection:
            debug('Closing broker connection...')
            self.maybe_conn_error(connection.close)

    def stop_consumers(self, close_connection=True, join=True):
        """Stop consuming tasks and broadcast commands, also stops
        the heartbeat thread and event dispatcher.

        Does nothing unless the consumer is in the ``RUN`` state.

        :keyword close_connection: Set to False to skip closing the broker
                                    connection.
        :keyword join: Set to False to skip cancelling the task and
                       broadcast consumers.

        """
        if self._state != RUN:
            return

        if self.heart:
            # Stop the heartbeat thread if it's running.
            debug('Heart: Going into cardiac arrest...')
            # NOTE(review): relies on Heart.stop() returning None to
            # clear the attribute -- confirm.
            self.heart = self.heart.stop()

        debug('Cancelling task consumer...')
        if join and self.task_consumer:
            self.maybe_conn_error(self.task_consumer.cancel)

        if self.event_dispatcher:
            debug('Shutting down event dispatcher...')
            self.event_dispatcher = \
                self.maybe_conn_error(self.event_dispatcher.close)

        debug('Cancelling broadcast consumer...')
        if join and self.broadcast_consumer:
            self.maybe_conn_error(self.broadcast_consumer.cancel)

        if close_connection:
            self.close_connection()

    def on_decode_error(self, message, exc):
        """Callback called if an error occurs while decoding
        a message received.

        Simply logs the error and acknowledges the message so it
        doesn't enter a loop.

        :param message: The message with errors.
        :param exc: The original exception instance.

        """
        crit("Can't decode message body: %r (type:%r encoding:%r raw:%r)",
             exc, message.content_type, message.content_encoding,
             dump_body(message, message.body))
        message.ack()

    def reset_pidbox_node(self):
        """Sets up the process mailbox.

        With a green pool the node is run by :meth:`_green_pidbox_node`
        spawned in the pool; otherwise it listens on a new channel of
        the current connection.

        """
        self.stop_pidbox_node()
        # close previously opened channel if any.
        if self.pidbox_node.channel:
            try:
                self.pidbox_node.channel.close()
            except self.connection_errors + self.channel_errors:
                pass

        if self.pool is not None and self.pool.is_green:
            return self.pool.spawn_n(self._green_pidbox_node)
        self.pidbox_node.channel = self.connection.channel()
        self.broadcast_consumer = self.pidbox_node.listen(
            callback=self.on_control)

    def stop_pidbox_node(self):
        """Stop the broadcast consumer.

        If the green pidbox loop was started, signal it to shut down and
        wait for it; otherwise close the broadcast consumer's channel.

        """
        if self._pidbox_node_stopped:
            self._pidbox_node_shutdown.set()
            debug('Waiting for broadcast thread to shutdown...')
            self._pidbox_node_stopped.wait()
            self._pidbox_node_stopped = self._pidbox_node_shutdown = None
        elif self.broadcast_consumer:
            debug('Closing broadcast channel...')
            self.broadcast_consumer = \
                self.maybe_conn_error(self.broadcast_consumer.channel.close)

    def _green_pidbox_node(self):
        """Sets up the process mailbox when running in a greenlet
        environment.

        Opens its own connection and drains events until the shutdown
        event is set; the stopped event is always set on exit.

        """
        # THIS CODE IS TERRIBLE
        # Luckily work has already started rewriting the Consumer for 4.0.
        self._pidbox_node_shutdown = threading.Event()
        self._pidbox_node_stopped = threading.Event()
        try:
            with self._open_connection() as conn:
                info('pidbox: Connected to %s.', conn.as_uri())
                self.pidbox_node.channel = conn.default_channel
                self.broadcast_consumer = self.pidbox_node.listen(
                    callback=self.on_control)
                with self.broadcast_consumer:
                    while not self._pidbox_node_shutdown.is_set():
                        try:
                            conn.drain_events(timeout=1.0)
                        except socket.timeout:
                            # Timed out: check the shutdown event again.
                            pass
        finally:
            self._pidbox_node_stopped.set()

    def reset_connection(self):
        """Re-establish the broker connection and set up consumers,
        heartbeat and the event dispatcher."""
        debug('Re-establishing connection to the broker...')
        self.stop_consumers(join=False)

        # Clear internal queues to get rid of old messages.
        # They can't be acked anyway, as a delivery tag is specific
        # to the current channel.
        self.ready_queue.clear()
        self.timer.clear()

        # Re-establish the broker connection and setup the task consumer.
        self.connection = self._open_connection()
        info('consumer: Connected to %s.', self.connection.as_uri())
        self.task_consumer = self.app.amqp.TaskConsumer(
            self.connection, on_decode_error=self.on_decode_error)
        # QoS: Reset prefetch window.
        self.qos = QoS(self.task_consumer, self.initial_prefetch_count)
        self.qos.update()

        # Setup the process mailbox.
        self.reset_pidbox_node()

        # Flush events sent while connection was down.
        prev_event_dispatcher = self.event_dispatcher
        self.event_dispatcher = self.app.events.Dispatcher(
            self.connection, hostname=self.hostname,
            enabled=self.send_events)
        if prev_event_dispatcher:
            self.event_dispatcher.copy_buffer(prev_event_dispatcher)
            self.event_dispatcher.flush()

        # Restart heartbeat thread.
        self.restart_heartbeat()

        # reload all task's execution strategies.
        self.update_strategies()

        # We're back!
        self._state = RUN

    def restart_heartbeat(self):
        """Restart the heartbeat thread.

        This thread sends heartbeat events at intervals so monitors
        can tell if the worker is off-line/missing.

        """
        self.heart = Heart(self.timer, self.event_dispatcher)
        self.heart.start()

    def _open_connection(self):
        """Establish the broker connection.

        Will retry establishing the connection if the
        :setting:`BROKER_CONNECTION_RETRY` setting is enabled

        :returns: the connection when retry is disabled, otherwise the
            result of ``ensure_connection``.

        """
        conn = self.app.connection(heartbeat=self.amqheartbeat)

        # Callback called for each retry while the connection
        # can't be established.
        def _error_handler(exc, interval, next_step=CONNECTION_RETRY):
            if getattr(conn, 'alt', None) and interval == 0:
                next_step = CONNECTION_FAILOVER
            error(CONNECTION_ERROR, conn.as_uri(), exc,
                  next_step.format(when=humanize_seconds(interval, 'in', ' ')))

        # remember that the connection is lazy, it won't establish
        # until it's needed.
        if not self.app.conf.BROKER_CONNECTION_RETRY:
            # retry disabled, just call connect directly.
            conn.connect()
            return conn

        return conn.ensure_connection(
            _error_handler,
            self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
            callback=self.maybe_shutdown)

    def stop(self):
        """Stop consuming.

        Does not close the broker connection, so be sure to call
        :meth:`close_connection` when you are finished with it.

        """
        # Notifies other threads that this instance can't be used
        # anymore.
        self.close()
        debug('Stopping consumers...')
        self.stop_consumers(close_connection=False, join=True)

    def close(self):
        """Set the state to ``CLOSE`` so the consume loops stop."""
        self._state = CLOSE

    def maybe_shutdown(self):
        """Raise the shutdown exception requested through ``state``.

        :raises SystemExit: if ``state.should_stop`` is set.
        :raises SystemTerminate: if ``state.should_terminate`` is set.

        """
        if state.should_stop:
            raise SystemExit()
        elif state.should_terminate:
            raise SystemTerminate()

    def add_task_queue(self, queue, exchange=None, exchange_type=None,
                       routing_key=None, **options):
        """Start consuming from ``queue``.

        A queue unknown to ``app.amqp.queues`` is declared first, with
        the exchange defaulting to the queue name and the exchange type
        to ``'direct'``.  Nothing is added if the task consumer already
        consumes from the queue.

        """
        cset = self.task_consumer
        try:
            q = self.app.amqp.queues[queue]
        except KeyError:
            exchange = queue if exchange is None else exchange
            exchange_type = 'direct' if exchange_type is None \
                                     else exchange_type
            q = self.app.amqp.queues.select_add(queue,
                                                exchange=exchange,
                                                exchange_type=exchange_type,
                                                routing_key=routing_key,
                                                **options)
        if not cset.consuming_from(queue):
            cset.add_queue(q)
            cset.consume()
            logger.info('Started consuming from %r', queue)

    def cancel_task_queue(self, queue):
        """Deselect ``queue`` and cancel consuming from it."""
        self.app.amqp.queues.select_remove(queue)
        self.task_consumer.cancel_by_queue(queue)

    @property
    def info(self):
        """Returns information about this consumer instance
        as a dict.

        This is also the consumer related info returned by
        ``celeryctl stats``.

        """
        conninfo = {}
        if self.connection:
            conninfo = self.connection.info()
            conninfo.pop('password', None)  # don't send password.
        return {'broker': conninfo, 'prefetch_count': self.qos.value}