Esempio n. 1
0
class test_Consumer(Case):
    """Exercise consumer behaviour via MyKombuConsumer and BlockingConsumer.

    Each test builds its own consumer around the ``FastQueue`` and ``Timer``
    created in ``setUp``; ``tearDown`` stops that timer again.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.timer = Timer()

    def tearDown(self):
        self.timer.stop()

    def test_info(self):
        # ``info`` reports the prefetch count, and a falsy ``broker`` entry
        # until a connection has been assigned.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.qos = QoS(l.task_consumer, 10)
        info = l.info
        self.assertEqual(info['prefetch_count'], 10)
        self.assertFalse(info['broker'])

        l.connection = current_app.connection()
        info = l.info
        self.assertTrue(info['broker'])

    def test_start_when_closed(self):
        # Smoke test: start() on a consumer already in CLOSE state must
        # return without raising.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = CLOSE
        l.start()

    def test_connection(self):
        # Walk the connection through reset / stop cycles and check what is
        # left on the consumer after each step.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)

        l.reset_connection()
        self.assertIsInstance(l.connection, Connection)

        # Stopping consumers without closing keeps the connection around.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close_connection=False)
        self.assertTrue(l.connection)

        # The default stop also drops the connection and the task consumer.
        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, Connection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() on a fresh consumer must not raise.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = RUN
        l.close_connection()

        # stop_consumers() must close the event dispatcher and the heart.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        eventer = l.event_dispatcher = Mock()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        # A message without a task name should only produce a warning.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, unknown={'baz': '!!!'})
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.consumer.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # When converting the ETA overflows, the message is acknowledged.
        to_timestamp.side_effect = OverflowError()
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        # NOTE(review): ('2, 2') is a plain string, not a tuple -- confirm
        # this is intentional before changing it.
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=('2, 2'),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()
        l.update_strategies()
        l.qos = Mock()

        l.receive_message(m.decode(), m)
        self.assertTrue(to_timestamp.called)
        self.assertTrue(m.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        # ``kwargs`` is a string here, which must be logged as an invalid
        # task message.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=(1, 2),
                           kwargs='foobarbaz',
                           id=1)
        l.update_strategies()
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn('Received invalid task message', error.call_args[0][0])

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        # An undecodable message is acked and reported at critical level.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        l.on_decode_error(message, KeyError('foo'))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def test_receieve_message(self):
        # A valid task message without ETA lands on the ready queue as a
        # Request and leaves the timer untouched.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={})
        l.update_strategies()

        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_connection_error(self):
        # The first consume_messages() raises a registered connection error
        # (KeyError); the second raises SyntaxError, which must escape
        # start().
        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.ready_queue,
                         timer=self.timer,
                         send_events=False,
                         pool=BasePool())
        l.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            l.start()
        l.heart.stop()
        l.timer.stop()

    def test_start_channel_error(self):
        # Regression test for AMQPChannelExceptions that can occur within the
        # consumer. (i.e. 404 errors)

        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.ready_queue,
                         timer=self.timer,
                         send_events=False,
                         pool=BasePool())

        l.channel_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            l.start()
        l.heart.stop()
        l.timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        # drain_events() clears the consumer's connection and then raises
        # socket.timeout; consume_messages() must return without raising.
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer, 10)
        l.consume_messages()

    def test_consume_messages_when_socket_error(self):
        # socket.error propagates while the consumer is in RUN state, but
        # not once the state is CLOSE.
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10)
        with self.assertRaises(socket.error):
            l.consume_messages()

        l._state = CLOSE
        l.connection = c
        l.consume_messages()

    def test_consume_messages(self):
        # consume_messages() starts consuming and applies the prefetch
        # count; QoS decrements are pushed to the task consumer on update().
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        l.task_consumer.qos = Mock()
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        # AttributeError and the registered connection/channel errors are
        # swallowed; any other exception (IndexError) propagates.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(Mock(side_effect=AttributeError('foo')))
        l.maybe_conn_error(Mock(side_effect=KeyError('foo')))
        l.maybe_conn_error(Mock(side_effect=SyntaxError('foo')))
        with self.assertRaises(IndexError):
            l.maybe_conn_error(Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        # Applying an ETA task reserves it, lowers the QoS value by one and
        # puts the task on the ready queue.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.qos = QoS(None, 10)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A message carrying an ISO-8601 ETA is scheduled on the timer and
        # raises the QoS value.
        if sys.version_info < (2, 6):
            raise SkipTest('test broken on Python 2.5')
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        l.receive_message(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        # The third field of each timer queue item is the scheduled entry;
        # its first argument is the request.
        found = any(entry[2].args[0].name == foo_task.name
                    for entry in self.timer.queue)
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)
        l.timer.stop()

    def test_on_control(self):
        # Control messages are forwarded to the pidbox node; only the
        # ValueError case is asserted to trigger reset_pidbox_node().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()
        l.reset_pidbox_node = Mock()

        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = KeyError('foo')
        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = ValueError('foo')
        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
        l.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        # A message whose id is in the revoked set never reaches the queue.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, timer=self.timer)
        backend = Mock()
        task_id = uuid()
        t = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # An unregistered task name yields a falsy result and nothing is
        # queued or scheduled.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        self.assertFalse(l.receive_message(m.decode(), m))
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        # reject() raising a connection error must not escape
        # receive_message(); it is logged at critical level instead.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.connection_errors = (socket.error, )
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())
        m.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receieve_message_eta(self):
        # A message with an ETA one day ahead is held on the timer and not
        # placed on the ready queue.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend,
            task=foo_task.name,
            args=[2, 4, 8],
            kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )

        l.reset_connection()
        # Also reset once with connection retry disabled, restoring the
        # original setting afterwards.
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.timer.stop()
        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()

    def test_reset_pidbox_node(self):
        # A connection error while closing the old pidbox channel must not
        # escape reset_pidbox_node().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()
        chan = l.pidbox_node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error('foo')
        l.connection_errors = (socket.error, )
        l.reset_pidbox_node()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a green pool the pidbox node is spawned through the pool.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pool = Mock()
        l.pool.is_green = True
        l.reset_pidbox_node()
        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)

    def test__green_pidbox_node(self):
        # Run the green pidbox loop against a fake connection that times
        # out once and then signals shutdown.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()

        class BConsumer(Mock):
            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        l.pidbox_node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out; the second clears the consumer's
                # connection and sets the shutdown event.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                self.obj._pidbox_node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l._open_connection = lambda: Connection(obj=l)
        l._green_pidbox_node()

        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
        self.assertTrue(l.broadcast_consumer)
        l.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_open_connection_errback(self, sleep, connect):
        # The first connect attempt raises a registered connection error;
        # _open_connection() must retry and succeed.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        from kombu.transport.memory import Transport
        # ``connection_errors`` is a class attribute: remember the original
        # value so the override does not leak into other tests.
        previous_errors = Transport.connection_errors
        Transport.connection_errors = (StdChannelError, )

        def effect():
            if connect.call_count > 1:
                return
            raise StdChannelError()

        connect.side_effect = effect
        try:
            l._open_connection()
            connect.assert_called_with()
        finally:
            Transport.connection_errors = previous_errors

    def test_stop_pidbox_node(self):
        # Smoke test: stop_pidbox_node() returns when the stopped event is
        # already set.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._pidbox_node_stopped = Event()
        l._pidbox_node_shutdown = Event()
        l._pidbox_node_stopped.set()
        l.stop_pidbox_node()

    def test_start__consume_messages(self):
        # start() calls the init callback and consume_messages(); errors
        # that are not handled propagate to the caller.
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Fails once consume_messages() has run at least once.
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.ready_queue,
                      timer=self.timer,
                      init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.consume_messages = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.ready_queue,
                      timer=self.timer,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.consume_messages = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertTrue(l.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        # reset_connection() must work on a consumer that has no pool.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        self.assertIsNone(l.pool)
        l.reset_connection()

    def test_on_task_revoked(self):
        # Smoke test: on_task() with a task reporting itself revoked.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = True
        l.on_task(task)

    def test_on_task_no_events(self):
        # Smoke test: on_task() with events disabled and no ETA.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = False
        l.event_dispatcher = Mock()
        l.event_dispatcher.enabled = False
        task.eta = None
        l._does_info = False
        l.on_task(task)
Esempio n. 2
0
class test_Consumer(AppCase):
    def setup(self):
        # Per-test fixtures: a buffer, a timer and a three-argument task
        # registered on the test app.
        self.buffer = FastQueue()
        self.timer = Timer()

        def foo_task(x, y, z):
            return x * y * z

        # Same as decorating ``foo_task`` with ``@self.app.task(shared=False)``.
        self.foo_task = self.app.task(shared=False)(foo_task)

    def teardown(self):
        # Stop the timer held in ``self.timer``.
        timer = self.timer
        timer.stop()

    def test_info(self):
        # The controller's stats() must expose the consumer's prefetch count
        # and a truthy ``broker`` entry.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.task_consumer = Mock()
        cons.qos = QoS(cons.task_consumer.qos, 10)
        cons.connection = Mock()
        cons.connection.info.return_value = {'foo': 'bar'}
        cons.controller = cons.app.WorkController()
        cons.controller.pool = Mock()
        cons.controller.pool.info.return_value = [Mock(), Mock()]
        cons.controller.consumer = cons
        stats = cons.controller.stats()
        self.assertEqual(stats['prefetch_count'], 10)
        self.assertTrue(stats['broker'])

    def test_start_when_closed(self):
        # Smoke test: start() with the blueprint already in CLOSE state
        # must return without raising.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = CLOSE
        cons.start()

    def test_connection(self):
        # Drive the blueprint through start / restart / shutdown and check
        # what remains on the consumer after each phase.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

        cons.blueprint.start(cons)
        self.assertIsInstance(cons.connection, Connection)

        # A restart leaves a connection in place.
        cons.blueprint.state = RUN
        cons.event_dispatcher = None
        cons.blueprint.restart(cons)
        self.assertTrue(cons.connection)

        # Shutdown clears both the connection and the task consumer.
        cons.blueprint.state = RUN
        cons.shutdown()
        for attr in ('connection', 'task_consumer'):
            self.assertIsNone(getattr(cons, attr))

        cons.blueprint.start(cons)
        self.assertIsInstance(cons.connection, Connection)
        cons.blueprint.restart(cons)

        cons.stop()
        cons.shutdown()
        for attr in ('connection', 'task_consumer'):
            self.assertIsNone(getattr(cons, attr))

    def test_close_connection(self):
        # Shutting down the Connection step closes and clears the
        # consumer's connection.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        conn_step = find_step(cons, consumer.Connection)
        broker = cons.connection = Mock()
        conn_step.shutdown(cons)
        self.assertTrue(broker.close.called)
        self.assertIsNone(cons.connection)

        # Shutting down the Events and Heart steps closes the dispatcher
        # and the heart.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        dispatcher = cons.event_dispatcher = Mock()
        dispatcher.enabled = True
        heart = cons.heart = MockHeart()
        cons.blueprint.state = RUN
        events_step = find_step(cons, consumer.Events)
        events_step.shutdown(cons)
        heart_step = find_step(cons, consumer.Heart)
        heart_step.shutdown(cons)
        self.assertTrue(dispatcher.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        # A message without a task name should only produce a warning.
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        message = create_message(Mock(), unknown={'baz': '!!!'})
        cons.event_dispatcher = Mock()
        cons.node = MockNode()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # With to_timestamp() raising OverflowError, the message must end
        # up acknowledged.
        to_timestamp.side_effect = OverflowError()
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        # NOTE(review): ('2, 2') is a plain string, not a tuple -- confirm
        # this is intentional.
        message = create_message(
            Mock(),
            task=self.foo_task.name,
            args=('2, 2'),
            kwargs={},
            eta=datetime.now().isoformat(),
        )
        cons.event_dispatcher = Mock()
        cons.node = MockNode()
        cons.update_strategies()
        cons.qos = Mock()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        self.assertTrue(message.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        # ``kwargs`` is a string here, which must be logged as an invalid
        # task message.
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        # Statement order is kept as-is: the dispatcher is assigned both
        # before and after update_strategies().
        cons.event_dispatcher = Mock()
        cons.steps.pop()
        message = create_message(
            Mock(),
            task=self.foo_task.name,
            args=(1, 2),
            kwargs='foobarbaz',
            id=1,
        )
        cons.update_strategies()
        cons.event_dispatcher = Mock()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        logged = error.call_args[0][0]
        self.assertIn('Received invalid task message', logged)

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        # An undecodable message is acked and reported at critical level.
        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        undecodable = MockMessage()
        cons.on_decode_error(undecodable, KeyError('foo'))
        self.assertTrue(undecodable.ack.call_count)
        logged = crit.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def _get_on_message(self, l):
        # Run the consumer loop once against mocks and return the callback
        # it registered on the task consumer.  The mocked connection raises
        # SystemExit from drain_events() so the loop terminates.
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = Mock()
        l.task_consumer = Mock()
        broker = l.connection = Mock()
        broker.drain_events.side_effect = SystemExit()

        with self.assertRaises(SystemExit):
            l.loop(*l.loop_args())
        register = l.task_consumer.register_callback
        self.assertTrue(register.called)
        # First positional argument of the most recent call.
        return register.call_args[0][0]

    def test_receieve_message(self):
        # A valid task message without ETA lands in the buffer as a Request
        # and leaves the timer untouched.
        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        cons.event_dispatcher = Mock()
        message = create_message(
            Mock(),
            task=self.foo_task.name,
            args=[2, 4, 8],
            kwargs={},
        )
        cons.update_strategies()
        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)

        request = self.buffer.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.name, self.foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_channel_error(self):
        # loop() raises KeyError on its first call; with KeyError listed in
        # ``channel_errors`` the test expects it to escape start().
        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                first_call = not self.iterations
                if first_call:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        cons = MockConsumer(
            self.buffer.put,
            timer=self.timer,
            send_events=False,
            pool=BasePool(),
            app=self.app,
        )
        cons.channel_errors = (KeyError, )
        self.assertRaises(KeyError, cons.start)
        cons.timer.stop()

    def test_start_connection_error(self):
        # loop() raises KeyError first and SyntaxError afterwards; with
        # KeyError listed in ``connection_errors`` the test expects
        # SyntaxError to escape start().
        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                first_call = not self.iterations
                if first_call:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        cons = MockConsumer(
            self.buffer.put,
            timer=self.timer,
            send_events=False,
            pool=BasePool(),
            app=self.app,
        )

        cons.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            cons.start()
        cons.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        # drain_events() clears the consumer's connection and then raises
        # socket.timeout; loop() must return without raising.
        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.connection = Connection()
        cons.task_consumer = Mock()
        cons.connection.obj = cons
        cons.qos = QoS(cons.task_consumer.qos, 10)
        loop_args = cons.loop_args()
        cons.loop(*loop_args)

    def test_loop_when_socket_error(self):
        # socket.error propagates while the blueprint is in RUN state, but
        # not once the state is CLOSE.
        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        conn = cons.connection = Connection()
        cons.connection.obj = cons
        cons.task_consumer = Mock()
        cons.qos = QoS(cons.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            cons.loop(*cons.loop_args())

        # drain_events() cleared the connection above; put it back.
        cons.blueprint.state = CLOSE
        cons.connection = conn
        cons.loop(*cons.loop_args())

    def test_loop(self):
        # loop() starts consuming and applies the prefetch count; QoS
        # decrements are pushed to the task consumer on update().
        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        cons.connection = Connection()
        cons.connection.obj = cons
        cons.task_consumer = Mock()
        cons.qos = QoS(cons.task_consumer.qos, 10)

        # Run the loop twice, as the original test does.
        cons.loop(*cons.loop_args())
        cons.loop(*cons.loop_args())
        self.assertTrue(cons.task_consumer.consume.call_count)
        cons.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(cons.qos.value, 10)
        cons.qos.decrement_eventually()
        self.assertEqual(cons.qos.value, 9)
        cons.qos.update()
        self.assertEqual(cons.qos.value, 9)
        cons.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        # Exceptions listed in connection_errors/channel_errors are
        # swallowed by ignore_errors(); anything else propagates.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.connection_errors = (
            AttributeError,
            KeyError,
        )
        cons.channel_errors = (SyntaxError, )
        for swallowed in (AttributeError, KeyError, SyntaxError):
            ignore_errors(cons, Mock(side_effect=swallowed('foo')))
        self.assertRaises(
            IndexError,
            ignore_errors, cons, Mock(side_effect=IndexError('foo')),
        )

    def test_apply_eta_task(self):
        # Applying an ETA task reserves it, lowers the QoS value by one and
        # puts the task into the buffer.
        from celery.worker import state
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.qos = QoS(None, 10)

        task = object()
        initial_value = cons.qos.value
        cons.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(cons.qos.value, initial_value - 1)
        queued = self.buffer.get_nowait()
        self.assertIs(queued, task)

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-8601 ETA one day ahead is scheduled on the
        # timer and raises the QoS value.
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(
            Mock(),
            task=self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8],
            kwargs={},
        )

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 1)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        # The third field of each timer queue item is the scheduled entry;
        # its first argument is the request.
        found = any(entry[2].args[0].name == self.foo_task.name
                    for entry in self.timer.queue)
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)
        l.timer.stop()

    def test_pidbox_callback(self):
        # Messages are forwarded to the node's handle_message(); only the
        # ValueError case is asserted to trigger a reset.
        cons = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        control_step = find_step(cons, consumer.Control)
        box = control_step.box
        box.node = Mock()
        box.reset = Mock()

        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        box.node = Mock()
        box.node.handle_message.side_effect = KeyError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        box.node = Mock()
        box.node.handle_message.side_effect = ValueError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')
        self.assertTrue(box.reset.called)

    def test_revoke(self):
        # A message whose id is in the revoked set never reaches the buffer.
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        # Named ``task_id`` so the builtin ``id`` is not shadowed.
        task_id = uuid()
        t = create_message(backend,
                           task=self.foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        callback = self._get_on_message(l)
        callback(t.decode(), t)
        self.assertTrue(self.buffer.empty())

    def test_receieve_message_not_registered(self):
        # A message for the task name 'x.X.31x' must produce a falsy result
        # and leave both the buffer and the timer empty.
        kc = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        kc.blueprint.state = RUN
        kc.steps.pop()
        message = create_message(
            Mock(), task='x.X.31x', args=[2, 4, 8], kwargs={},
        )

        kc.event_dispatcher = Mock()
        on_message = self._get_on_message(kc)
        result = on_message(message.decode(), message)
        self.assertFalse(result)
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        # The message's reject() raises socket.error (registered as a
        # connection error); the callback must still return a falsy value,
        # warn, log at critical level, and queue nothing.
        kc = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        kc.blueprint.state = RUN
        message = create_message(Mock(), args=[2, 4, 8], kwargs={})

        kc.event_dispatcher = Mock()
        kc.connection_errors = (socket.error, )
        message.reject = Mock(side_effect=socket.error('foo'))
        on_message = self._get_on_message(kc)
        self.assertFalse(on_message(message.decode(), message))
        self.assertTrue(warn.call_count)
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())
        message.reject.assert_called_with(requeue=False)
        self.assertTrue(logger.critical.call_count)

    def test_receive_message_eta(self):
        # A message with an ETA one day ahead must end up on the timer as a
        # Request and must not be put on the buffer.
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.steps.pop()
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend,
            task=self.foo_task.name,
            args=[2, 4, 8],
            kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )

        try:
            l.blueprint.start(l)
            p = l.app.conf.BROKER_CONNECTION_RETRY
            l.app.conf.BROKER_CONNECTION_RETRY = False
            try:
                # Start again with connection retry switched off.
                l.blueprint.start(l)
            finally:
                # Restore the shared app setting even when the start above
                # raises, so the override cannot leak into other tests.
                l.app.conf.BROKER_CONNECTION_RETRY = p
            l.blueprint.restart(l)
            l.event_dispatcher = Mock()
            callback = self._get_on_message(l)
            callback(m.decode(), m)
        finally:
            # Always stop the timer thread, whatever happened above.
            l.timer.stop()
            l.timer.join()

        # Timer entries are (eta, priority, entry) triples.
        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, self.foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        # reset() must call close() on the node's channel, and the
        # socket.error raised by close() (registered as a connection error)
        # is expected not to escape.
        kc = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        box = find_step(kc, consumer.Control).box
        box.node = Mock()
        channel = Mock()
        box.node.channel = channel
        kc.connection = Mock()
        channel.close.side_effect = socket.error('foo')
        kc.connection_errors = (socket.error, )
        box.reset()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a pool flagged is_green the Control step must use gPidbox and
        # starting it must spawn the box loop on the pool.
        from celery.worker.pidbox import gPidbox
        green_pool = Mock(is_green=True)
        kc = MyKombuConsumer(
            self.buffer.put, timer=self.timer, pool=green_pool, app=self.app,
        )
        control = find_step(kc, consumer.Control)
        self.assertIsInstance(control.box, gPidbox)
        control.start(kc)
        kc.pool.spawn_n.assert_called_with(control.box.loop, kc)

    def test__green_pidbox_node(self):
        # Runs the control box loop against hand-written consumer and
        # connection stand-ins, then checks that the box consumed from the
        # node and that the first recorded connection was closed.
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put,
                            timer=self.timer,
                            pool=pool,
                            app=self.app)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        class BConsumer(Mock):
            # Mock usable as a context manager: entering calls consume(),
            # leaving calls cancel().
            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            # Stand-in connection; every instance registers itself in
            # ``connections`` so the test can inspect it afterwards.
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call raises socket.timeout; later calls clear the
                # owner's connection and set the box's _node_shutdown event.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l.connect = lambda: Connection(obj=l)
        # NOTE: ``controller`` is rebound here; drain_events() resolves the
        # name when it runs, so it uses this binding.
        controller = find_step(l, consumer.Control)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_connect_errback(self, sleep, connect):
        # The patched _establish_connection raises ChannelError on its first
        # call and returns normally afterwards; connect() is expected to get
        # past that first failure.
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        from kombu.transport.memory import Transport

        def effect():
            if connect.call_count > 1:
                return
            raise ChannelError('error')

        connect.side_effect = effect
        # Override the class attribute only while connecting, so the change
        # cannot leak into other tests that use the memory transport.
        with patch.object(Transport, 'connection_errors', (ChannelError, )):
            l.connect()
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        # Stopping the Control step must return when the node-stopped event
        # is already set.
        kc = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        control = find_step(kc, consumer.Control)
        stopped = Event()
        control._node_stopped = stopped
        control._node_shutdown = Event()
        stopped.set()
        control.stop(kc)

    def test_start__loop(self):
        # Drive start() with a stubbed loop and check that errors raised by
        # the loop (KeyError, then socket.error) come out of start().
        class _QoS(object):
            # Minimal QoS stand-in: update() copies value into prev.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Only fails once the loop has run at least one iteration.
                if self.iterations < 1:
                    return
                raise KeyError('foo')

        init_callback = Mock()
        kc = _Consumer(
            self.buffer.put,
            timer=self.timer,
            init_callback=init_callback,
            app=self.app,
        )
        kc.task_consumer = Mock()
        kc.broadcast_consumer = Mock()
        kc.qos = _QoS()
        kc.connection = Connection()
        kc.iterations = 0

        def failing_loop(*args, **kwargs):
            # Counts iterations, syncs the QoS once, fails from the second
            # iteration onwards.
            kc.iterations += 1
            if kc.qos.prev != kc.qos.value:
                kc.qos.update()
            if kc.iterations >= 2:
                raise KeyError('foo')

        kc.loop = failing_loop
        self.assertRaises(KeyError, kc.start)
        self.assertEqual(kc.iterations, 2)
        self.assertEqual(kc.qos.prev, kc.qos.value)

        # Second scenario: the loop itself raises socket.error.
        init_callback.reset_mock()
        kc = _Consumer(
            self.buffer.put,
            timer=self.timer,
            app=self.app,
            send_events=False,
            init_callback=init_callback,
        )
        kc.qos = _QoS()
        kc.task_consumer = Mock()
        kc.broadcast_consumer = Mock()
        kc.connection = Connection()
        kc.loop = Mock(side_effect=socket.error('foo'))
        self.assertRaises(socket.error, kc.start)
        self.assertTrue(kc.loop.call_count)

    def test_reset_connection_with_no_node(self):
        # With the last step removed and no pool configured, starting the
        # blueprint must not raise.
        kc = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        kc.steps.pop()
        self.assertEqual(None, kc.pool)
        kc.blueprint.start(kc)
Esempio n. 3
0
class test_Consumer(unittest.TestCase):
    # Tests for MyKombuConsumer built with the positional
    # (ready_queue, eta_schedule, logger) arguments and send_events=False.

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        default_logger = app_or_default().log.get_default_logger()
        default_logger.setLevel(0)
        self.logger = default_logger

    def tearDown(self):
        self.eta_schedule.stop()

    def _make_consumer(self, logger=None):
        # Build the consumer under test on the shared queue and schedule;
        # tests that inspect log output pass their own logger.
        if logger is None:
            logger = self.logger
        return MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                               send_events=False)

    def test_connection(self):
        # reset_connection() opens a BrokerConnection; stop_consumers() and
        # close_connection() must leave connection and task_consumer None.
        kc = self._make_consumer()

        kc.reset_connection()
        self.assertIsInstance(kc.connection, BrokerConnection)

        kc.stop_consumers()
        self.assertIsNone(kc.connection)
        self.assertIsNone(kc.task_consumer)

        kc.reset_connection()
        self.assertIsInstance(kc.connection, BrokerConnection)
        kc.stop_consumers()

        kc.stop()
        kc.close_connection()
        self.assertIsNone(kc.connection)
        self.assertIsNone(kc.task_consumer)

    def test_close_connection(self):
        # close_connection() on a fresh consumer must not raise, and
        # stop_consumers() must close both the event dispatcher and heart.
        idle = self._make_consumer()
        idle._state = RUN
        idle.close_connection()

        kc = self._make_consumer()
        eventer = MockEventDispatcher()
        kc.event_dispatcher = eventer
        heart = MockHeart()
        kc.heart = heart
        kc._state = RUN
        kc.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A message without a task name must emit an "unknown message"
        # warning.
        kc = self._make_consumer()
        message = create_message(MockBackend(), unknown={"baz": "!!!"})
        kc.event_dispatcher = MockEventDispatcher()
        kc.pidbox_node = MockNode()

        def with_catch_warnings(log):
            kc.receive_message(message.decode(), message)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        execute_context(catch_warnings(record=True), with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        # When to_timestamp raises OverflowError for the ETA, the message
        # must still be acknowledged.
        kc = self._make_consumer()
        calls = []

        def to_timestamp(d):
            calls.append(d)
            raise OverflowError()

        message = create_message(MockBackend(), task=foo_task.name,
                                 args=("2, 2"),
                                 kwargs={},
                                 eta=datetime.now().isoformat())
        kc.event_dispatcher = MockEventDispatcher()
        kc.pidbox_node = MockNode()

        # Swap the module-level to_timestamp for the failing stub and put
        # the original back afterwards.
        from celery.worker import consumer
        original = consumer.to_timestamp
        consumer.to_timestamp = to_timestamp
        try:
            kc.receive_message(message.decode(), message)
            self.assertTrue(message.acknowledged)
            self.assertTrue(calls)
        finally:
            consumer.to_timestamp = original

    def test_receive_message_InvalidTaskError(self):
        # kwargs given as a string must be logged as an invalid task.
        logger = MockLogger()
        kc = self._make_consumer(logger)
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        kc.event_dispatcher = MockEventDispatcher()
        kc.pidbox_node = MockNode()

        kc.receive_message(message.decode(), message)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        # on_decode_error() must ack the message and log the failure.
        logger = MockLogger()
        kc = self._make_consumer(logger)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        broken = MockMessage()
        kc.on_decode_error(broken, KeyError("foo"))
        self.assertTrue(broken.acked)
        self.assertIn("Message decoding error", logger.logged[0])

    def test_receieve_message(self):
        # A plain task message must land on the ready queue as a
        # TaskRequest and leave the ETA schedule empty.
        kc = self._make_consumer()
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={})

        kc.event_dispatcher = MockEventDispatcher()
        kc.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-format ETA must be scheduled and must cause
        # a qos() call on the task consumer.

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        kc = self._make_consumer()
        message = create_message(MockBackend(), task=foo_task.name,
                                 eta=datetime.now().isoformat(),
                                 args=[2, 4, 8], kwargs={})

        kc.task_consumer = MockConsumer()
        kc.qos = QoS(kc.task_consumer, kc.initial_prefetch_count, kc.logger)
        kc.event_dispatcher = MockEventDispatcher()
        kc.receive_message(message.decode(), message)
        kc.eta_schedule.stop()

        scheduled = [entry[2].args[0].task_name
                     for entry in self.eta_schedule.queue]
        self.assertTrue(foo_task.name in scheduled)
        self.assertTrue(kc.task_consumer.prefetch_count_incremented)
        kc.eta_schedule.stop()

    def test_revoke(self):
        # A message whose id is in the revoked set must not reach the
        # ready queue.
        ready_queue = FastQueue()
        kc = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                             send_events=False)
        task_id = gen_unique_id()
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        kc.receive_message(message.decode(), message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for the task name "x.X.31x" must give a falsy result
        # and queue nothing.
        kc = self._make_consumer()
        message = create_message(MockBackend(), task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        kc.event_dispatcher = MockEventDispatcher()
        result = kc.receive_message(message.decode(), message)
        self.assertFalse(result)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # A message with an ETA one day ahead must be held on the ETA
        # schedule and must not be placed on the ready queue.
        kc = self._make_consumer()
        dispatcher = MockEventDispatcher()
        kc.event_dispatcher = dispatcher
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={},
                                 eta=tomorrow.isoformat())

        kc.reset_connection()
        saved_retry = kc.app.conf.BROKER_CONNECTION_RETRY
        kc.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            kc.reset_connection()
        finally:
            kc.app.conf.BROKER_CONNECTION_RETRY = saved_retry
        kc.stop_consumers()
        self.assertTrue(dispatcher.flushed)
        kc.event_dispatcher = MockEventDispatcher()
        kc.receive_message(message.decode(), message)
        kc.eta_schedule.stop()

        # Schedule entries are (eta, priority, entry) triples.
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        _eta, _priority, entry = in_hold
        request = entry.args[0]
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # Errors raised by the main loop (KeyError, socket.error) must come
        # out of start(), after init_callback has run.

        class _QoS(object):
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                if self.iterations < 1:
                    return
                raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        def failing_mainloop(target, exc):
            # Generator main loop: yields once, marks one iteration on
            # *target*, then raises *exc*.
            def _mainloop(limit=None):
                yield True
                target.iterations = 1
                raise exc
            return _mainloop

        kc = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
        kc.task_consumer = MockConsumer()
        kc.qos = _QoS()
        kc.connection = BrokerConnection()

        kc._mainloop = failing_mainloop(kc, KeyError("foo"))
        self.assertRaises(KeyError, kc.start)
        self.assertTrue(called_back[0])
        self.assertEqual(kc.iterations, 1)
        self.assertEqual(kc.qos.prev, kc.qos.next)

        kc = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
        kc.qos = _QoS()
        kc.task_consumer = MockConsumer()
        kc.connection = BrokerConnection()

        kc._mainloop = failing_mainloop(kc, socket.error("foo"))
        self.assertRaises(socket.error, kc.start)
        self.assertTrue(called_back[0])
        self.assertEqual(kc.iterations, 1)
Esempio n. 4
0
class test_Consumer(Case):

    def setUp(self):
        # Every test gets its own ready queue and timer.
        self.ready_queue, self.timer = FastQueue(), Timer()

    def tearDown(self):
        # Stop the timer created in setUp.
        timer = self.timer
        timer.stop()

    def test_info(self):
        # info must report the prefetch count, and a broker entry that is
        # falsy until a connection has been assigned.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.qos = QoS(kc.task_consumer, 10)
        before = kc.info
        self.assertEqual(before["prefetch_count"], 10)
        self.assertFalse(before["broker"])

        kc.connection = current_app.broker_connection()
        after = kc.info
        self.assertTrue(after["broker"])

    def test_start_when_closed(self):
        # start() on a consumer already in the CLOSE state must return
        # without raising.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc._state = CLOSE
        kc.start()

    def test_connection(self):
        # Walks the connection through reset / stop cycles and checks what
        # is left on the consumer after each step.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)

        kc.reset_connection()
        self.assertIsInstance(kc.connection, BrokerConnection)

        # Stopping consumers with close_connection=False keeps the
        # connection.
        kc._state = RUN
        kc.event_dispatcher = None
        kc.stop_consumers(close_connection=False)
        self.assertTrue(kc.connection)

        # The default stop also drops the connection and task consumer.
        kc._state = RUN
        kc.stop_consumers()
        self.assertIsNone(kc.connection)
        self.assertIsNone(kc.task_consumer)

        kc.reset_connection()
        self.assertIsInstance(kc.connection, BrokerConnection)
        kc.stop_consumers()

        kc.stop()
        kc.close_connection()
        self.assertIsNone(kc.connection)
        self.assertIsNone(kc.task_consumer)

    def test_close_connection(self):
        # close_connection() on a fresh consumer must not raise, and
        # stop_consumers() must close the enabled dispatcher and the heart.
        idle = MyKombuConsumer(self.ready_queue, timer=self.timer)
        idle._state = RUN
        idle.close_connection()

        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        eventer = Mock()
        kc.event_dispatcher = eventer
        eventer.enabled = True
        heart = MockHeart()
        kc.heart = heart
        kc._state = RUN
        kc.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch("celery.worker.consumer.warn")
    def test_receive_message_unknown(self, warn):
        # A message without a task name must trigger a warning.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(Mock(), unknown={"baz": "!!!"})
        kc.event_dispatcher = Mock()
        kc.pidbox_node = MockNode()

        kc.receive_message(message.decode(), message)
        self.assertTrue(warn.call_count)

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # With to_timestamp patched to raise OverflowError, the message
        # must still be acknowledged.
        to_timestamp.side_effect = OverflowError()
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(
            Mock(),
            task=foo_task.name,
            args=("2, 2"),
            kwargs={},
            eta=datetime.now().isoformat(),
        )
        kc.event_dispatcher = Mock()
        kc.pidbox_node = MockNode()
        kc.update_strategies()

        kc.receive_message(message.decode(), message)
        self.assertTrue(message.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    @patch("celery.worker.consumer.error")
    def test_receive_message_InvalidTaskError(self, error):
        # kwargs given as a string must be reported through error().
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(
            Mock(), task=foo_task.name, args=(1, 2), kwargs="foobarbaz", id=1,
        )
        kc.update_strategies()
        kc.event_dispatcher = Mock()
        kc.pidbox_node = MockNode()

        kc.receive_message(message.decode(), message)
        logged = error.call_args[0][0]
        self.assertIn("Received invalid task message", logged)

    @patch("celery.worker.consumer.crit")
    def test_on_decode_error(self, crit):
        # on_decode_error() must ack the message and report through crit().
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        broken = MockMessage()
        kc.on_decode_error(broken, KeyError("foo"))
        self.assertTrue(broken.ack.call_count)
        logged = crit.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def test_receieve_message(self):
        # A plain task message must land on the ready queue as a Request
        # and leave the timer empty.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(
            Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={},
        )
        kc.update_strategies()

        kc.event_dispatcher = Mock()
        kc.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_connection_error(self):
        # The first consume_messages() call raises KeyError (registered as
        # a connection error); the following call raises SyntaxError, which
        # must come out of start().

        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        kc = MockConsumer(self.ready_queue, timer=self.timer,
                          send_events=False, pool=BasePool())
        kc.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, kc.start)
        kc.heart.stop()
        kc.timer.stop()

    def test_start_channel_error(self):
        # Regression test for AMQPChannelExceptions that can occur within
        # the consumer (e.g. 404 errors): the first KeyError is registered
        # as a channel error, the later SyntaxError must escape start().

        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        kc = MockConsumer(self.ready_queue, timer=self.timer,
                          send_events=False, pool=BasePool())

        kc.channel_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            kc.start()
        kc.heart.stop()
        kc.timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        # socket.timeout raised while draining events must not escape
        # consume_messages().
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                # Clear the owner's connection, then time out.
                self.obj.connection = None
                raise socket.timeout(10)

        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.connection = Connection()
        kc.task_consumer = Mock()
        kc.connection.obj = kc
        kc.qos = QoS(kc.task_consumer, 10)
        kc.consume_messages()

    def test_consume_messages_when_socket_error(self):
        # socket.error while draining events must escape when the consumer
        # is in the RUN state and must not escape in the CLOSE state.
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error("foo")

        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc._state = RUN
        conn = Connection()
        kc.connection = conn
        kc.connection.obj = kc
        kc.task_consumer = Mock()
        kc.qos = QoS(kc.task_consumer, 10)
        self.assertRaises(socket.error, kc.consume_messages)

        # drain_events cleared the connection above; put it back.
        kc._state = CLOSE
        kc.connection = conn
        kc.consume_messages()

    def test_consume_messages(self):
        # consume_messages() must start the task consumer and push the
        # current prefetch count to it, including after a decrement.
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.connection = Connection()
        kc.connection.obj = kc
        kc.task_consumer = Mock()
        kc.qos = QoS(kc.task_consumer, 10)

        kc.consume_messages()
        kc.consume_messages()
        self.assertTrue(kc.task_consumer.consume.call_count)
        kc.task_consumer.qos.assert_called_with(prefetch_count=10)
        kc.qos.decrement()
        kc.consume_messages()
        kc.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        # AttributeError, KeyError and SyntaxError raised by the callable
        # must not escape maybe_conn_error(); IndexError must.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.connection_errors = (KeyError, )
        kc.channel_errors = (SyntaxError, )
        for exc_type in (AttributeError, KeyError, SyntaxError):
            kc.maybe_conn_error(Mock(side_effect=exc_type("foo")))
        self.assertRaises(
            IndexError,
            kc.maybe_conn_error, Mock(side_effect=IndexError("foo")),
        )

    def test_apply_eta_task(self):
        # apply_eta_task() must reserve the task, lower the QoS value by
        # one and put the task on the ready queue.
        from celery.worker import state
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.qos = QoS(None, 10)

        task = object()
        initial = kc.qos.value
        kc.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(kc.qos.value, initial - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-format ETA must show up in the timer queue
        # and must cause a qos() call on the task consumer.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(
            Mock(),
            task=foo_task.name,
            eta=datetime.now().isoformat(),
            args=[2, 4, 8],
            kwargs={},
        )

        kc.task_consumer = Mock()
        kc.qos = QoS(kc.task_consumer, kc.initial_prefetch_count)
        kc.event_dispatcher = Mock()
        kc.enabled = False
        kc.update_strategies()
        kc.receive_message(message.decode(), message)
        kc.timer.stop()

        # Third element of each queue entry carries the request in args[0].
        scheduled = [entry[2].args[0].name for entry in self.timer.queue]
        self.assertTrue(foo_task.name in scheduled)
        self.assertTrue(kc.task_consumer.qos.call_count)
        kc.timer.stop()

    def test_on_control(self):
        # on_control() must hand every message to pidbox_node.handle_message,
        # with and without handler errors.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.pidbox_node = Mock()
        kc.reset_pidbox_node = Mock()

        # Handler succeeds.
        kc.on_control("foo", "bar")
        kc.pidbox_node.handle_message.assert_called_with("foo", "bar")

        # Handler raises KeyError: on_control is expected not to re-raise.
        kc.pidbox_node = Mock()
        kc.pidbox_node.handle_message.side_effect = KeyError("foo")
        kc.on_control("foo", "bar")
        kc.pidbox_node.handle_message.assert_called_with("foo", "bar")

        # Handler raises ValueError: the pidbox node must be reset.
        kc.pidbox_node = Mock()
        kc.pidbox_node.handle_message.side_effect = ValueError("foo")
        kc.on_control("foo", "bar")
        kc.pidbox_node.handle_message.assert_called_with("foo", "bar")
        kc.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        # A message whose id is in the revoked set must not reach the
        # ready queue.
        ready_queue = FastQueue()
        kc = MyKombuConsumer(ready_queue, timer=self.timer)
        task_id = uuid()
        message = create_message(
            Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={}, id=task_id,
        )
        from celery.worker.state import revoked
        revoked.add(task_id)

        kc.receive_message(message.decode(), message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for the task name "x.X.31x" must give a falsy result
        # and leave the ready queue and timer empty.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(
            Mock(), task="x.X.31x", args=[2, 4, 8], kwargs={},
        )

        kc.event_dispatcher = Mock()
        result = kc.receive_message(message.decode(), message)
        self.assertFalse(result)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.timer.empty())

    @patch("celery.worker.consumer.warn")
    @patch("celery.worker.consumer.logger")
    def test_receieve_message_ack_raises(self, logger, warn):
        # The message's reject() raises socket.error (registered as a
        # connection error); receive_message() must still return a falsy
        # value, warn, log at critical level, and queue nothing.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(Mock(), args=[2, 4, 8], kwargs={})

        kc.event_dispatcher = Mock()
        kc.connection_errors = (socket.error, )
        message.reject = Mock(side_effect=socket.error("foo"))
        self.assertFalse(kc.receive_message(message.decode(), message))
        self.assertTrue(warn.call_count)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.timer.empty())
        message.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receieve_message_eta(self):
        # A message with an ETA one day ahead must be held on the timer as
        # a Request and must not be placed on the ready queue.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.event_dispatcher = Mock()
        kc.event_dispatcher._outbound_buffer = deque()
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_message(
            Mock(),
            task=foo_task.name,
            args=[2, 4, 8],
            kwargs={},
            eta=tomorrow.isoformat(),
        )

        kc.reset_connection()
        # Reset once more with connection retry switched off, restoring
        # the app setting afterwards.
        saved_retry = kc.app.conf.BROKER_CONNECTION_RETRY
        kc.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            kc.reset_connection()
        finally:
            kc.app.conf.BROKER_CONNECTION_RETRY = saved_retry
        kc.stop_consumers()
        kc.event_dispatcher = Mock()
        kc.receive_message(message.decode(), message)
        kc.timer.stop()

        # Timer entries are (eta, priority, entry) triples.
        in_hold = kc.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        _eta, _priority, entry = in_hold
        request = entry.args[0]
        self.assertIsInstance(request, Request)
        self.assertEqual(request.name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_reset_pidbox_node(self):
        # reset_pidbox_node() must call close() on the node's channel, and
        # the socket.error raised by close() (registered as a connection
        # error) is expected not to escape.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.pidbox_node = Mock()
        channel = Mock()
        kc.pidbox_node.channel = channel
        kc.connection = Mock()
        channel.close.side_effect = socket.error("foo")
        kc.connection_errors = (socket.error, )
        kc.reset_pidbox_node()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a pool flagged is_green, resetting the pidbox node must spawn
        # _green_pidbox_node on the pool.
        kc = MyKombuConsumer(self.ready_queue, timer=self.timer)
        kc.pool = Mock(is_green=True)
        kc.reset_pidbox_node()
        kc.pool.spawn_n.assert_called_with(kc._green_pidbox_node)

    def test__green_pidbox_node(self):
        """_green_pidbox_node() consumes on its own connection and closes it.

        The stand-in connection raises ``socket.timeout`` on the first drain;
        on the next one it clears the consumer's connection and sets the
        pidbox shutdown event.
        """
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.pidbox_node = Mock()

        class BConsumer(Mock):
            # Mock usable as a context manager: consume on enter,
            # cancel on exit.

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        consumer.pidbox_node.listen = BConsumer()
        opened = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                opened.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def drain_events(self, **kwargs):
                if self.calls:
                    self.obj.connection = None
                    self.obj._pidbox_node_shutdown.set()
                    return
                self.calls += 1
                raise socket.timeout()

            def close(self):
                self.closed = True

        def open_connection():
            return Connection(obj=consumer)

        consumer.connection = Mock()
        consumer._open_connection = open_connection
        consumer._green_pidbox_node()

        consumer.pidbox_node.listen.assert_called_with(
            callback=consumer.on_control)
        self.assertTrue(consumer.broadcast_consumer)
        consumer.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(consumer.connection)
        self.assertTrue(opened[0].closed)

    @patch("kombu.connection.BrokerConnection._establish_connection")
    @patch("kombu.utils.sleep")
    def test_open_connection_errback(self, sleep, connect):
        """_open_connection() with a first connection attempt that fails.

        The patched ``_establish_connection`` raises ``StdChannelError`` on
        its first call and returns normally on later calls.
        """
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        from kombu.transport.memory import Transport

        def effect():
            # Mock bumps call_count before invoking the side effect, so only
            # the very first call raises.
            if connect.call_count > 1:
                return
            raise StdChannelError()
        connect.side_effect = effect

        # Scope the override: assigning on the Transport class directly
        # would leak the altered error tuple into every later test.
        with patch.object(Transport, 'connection_errors',
                          (StdChannelError, )):
            l._open_connection()
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        """stop_pidbox_node() runs when the stopped event is already set."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        stopped = Event()
        consumer._pidbox_node_stopped = stopped
        consumer._pidbox_node_shutdown = Event()
        stopped.set()
        consumer.stop_pidbox_node()

    def test_start__consume_messages(self):
        """start() calls init_callback and lets loop errors propagate.

        Two scenarios: a KeyError (raised by the overridden
        ``reset_connection`` once one consume pass has run) and a
        ``socket.error`` raised directly by ``consume_messages``.
        """

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Refuse to reconnect after the first consume pass.
                if self.iterations >= 1:
                    raise KeyError("foo")

        on_init = Mock()
        consumer = _Consumer(self.ready_queue, timer=self.timer,
                             init_callback=on_init)
        consumer.task_consumer = Mock()
        consumer.broadcast_consumer = Mock()
        consumer.qos = _QoS()
        consumer.connection = BrokerConnection()
        consumer.iterations = 0

        def raises_KeyError(limit=None):
            consumer.iterations += 1
            if consumer.qos.prev != consumer.qos.value:
                consumer.qos.update()
            if consumer.iterations >= 2:
                raise KeyError("foo")

        consumer.consume_messages = raises_KeyError
        with self.assertRaises(KeyError):
            consumer.start()
        self.assertTrue(on_init.call_count)
        self.assertEqual(consumer.iterations, 1)
        self.assertEqual(consumer.qos.prev, consumer.qos.value)

        # Second scenario: events disabled, the loop itself fails.
        on_init.reset_mock()
        consumer = _Consumer(self.ready_queue, timer=self.timer,
                             send_events=False, init_callback=on_init)
        consumer.qos = _QoS()
        consumer.task_consumer = Mock()
        consumer.broadcast_consumer = Mock()
        consumer.connection = BrokerConnection()
        consumer.consume_messages = Mock(side_effect=socket.error("foo"))
        with self.assertRaises(socket.error):
            consumer.start()
        self.assertTrue(on_init.call_count)
        self.assertTrue(consumer.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """reset_connection() works on a consumer created without a pool."""
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        # Identity check, consistent with the other None assertions in
        # this test case.
        self.assertIsNone(l.pool)
        l.reset_connection()

    def test_on_task_revoked(self):
        """on_task() accepts a task whose revoked() returns True."""
        consumer = BlockingConsumer(self.ready_queue, timer=self.timer)
        revoked_task = Mock()
        revoked_task.revoked.return_value = True
        consumer.on_task(revoked_task)

    def test_on_task_no_events(self):
        """on_task() handles a non-revoked, ETA-less task with events off."""
        consumer = BlockingConsumer(self.ready_queue, timer=self.timer)
        request = Mock()
        request.revoked.return_value = False
        consumer.event_dispatcher = Mock()
        consumer.event_dispatcher.enabled = False
        request.eta = None
        consumer._does_info = False
        consumer.on_task(request)
Esempio n. 5
0
class test_Consumer(ConsumerCase):

    def setup(self):
        """Create the shared buffer, timer and a three-argument task."""
        self.buffer = FastQueue()
        self.timer = Timer()

        # shared=False: register the task on this test's app only.
        @self.app.task(shared=False)
        def foo_task(x, y, z):
            product = x * y * z
            return product

        self.foo_task = foo_task

    def teardown(self):
        """Stop the timer created in setup()."""
        timer = self.timer
        timer.stop()

    def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
                     without_mingle=True, without_gossip=True,
                     without_heartbeat=True, **kwargs):
        """Build a Consumer whose collaborators are replaced by mocks.

        Missing arguments fall back to the fixtures made in setup() (or a
        fresh Mock for ``controller``).  After construction the controller
        is replaced by ``app.WorkController()`` and the task consumer,
        connection, heart, pool, node and event dispatcher are mocked.
        """
        if controller is None:
            controller = Mock(name='.controller')
        if buffer is None:
            buffer = self.buffer.put
        if timer is None:
            timer = self.timer
        if app is None:
            app = self.app

        cons = Consumer(
            buffer,
            timer=timer,
            app=app,
            controller=controller,
            without_mingle=without_mingle,
            without_gossip=without_gossip,
            without_heartbeat=without_heartbeat,
            **kwargs
        )
        cons.task_consumer = Mock(name='.task_consumer')
        cons.qos = QoS(cons.task_consumer.qos, 10)
        cons.connection = Mock(name='.connection')
        cons.controller = cons.app.WorkController()
        cons.heart = Mock(name='.heart')
        cons.controller.consumer = cons
        # One pool mock shared by the consumer and its controller.
        shared_pool = Mock(name='.controller.pool')
        cons.pool = shared_pool
        cons.controller.pool = shared_pool
        cons.node = Mock(name='.node')
        cons.event_dispatcher = mock_event_dispatcher()
        return cons

    def NoopConsumer(self, *args, **kwargs):
        """Return a LoopConsumer whose ``loop`` is replaced by a Mock."""
        cons = self.LoopConsumer(*args, **kwargs)
        noop_loop = Mock(name='.loop')
        cons.loop = noop_loop
        return cons

    def test_info(self):
        """controller.stats() reports the prefetch count and broker info."""
        cons = self.NoopConsumer()
        cons.connection.info.return_value = {'foo': 'bar'}
        cons.controller.pool.info.return_value = [Mock(), Mock()]
        stats = cons.controller.stats()
        assert stats['prefetch_count'] == 10
        assert stats['broker']

    def test_start_when_closed(self):
        """start() can be called while the blueprint state is CLOSE."""
        cons = self.NoopConsumer()
        blueprint = cons.blueprint
        blueprint.state = CLOSE
        cons.start()

    def test_connection(self):
        """Connection lifecycle across blueprint start, restart and shutdown."""
        cons = self.NoopConsumer()

        # Starting the blueprint swaps the mock for a real Connection.
        cons.blueprint.start(cons)
        assert isinstance(cons.connection, Connection)

        # A restart keeps a connection around.
        cons.blueprint.state = RUN
        cons.event_dispatcher = None
        cons.blueprint.restart(cons)
        assert cons.connection

        # Shutdown drops both the connection and the task consumer.
        cons.blueprint.state = RUN
        cons.shutdown()
        assert cons.connection is None
        assert cons.task_consumer is None

        cons.blueprint.start(cons)
        assert isinstance(cons.connection, Connection)
        cons.blueprint.restart(cons)

        cons.stop()
        cons.shutdown()
        assert cons.connection is None
        assert cons.task_consumer is None

    def test_close_connection(self):
        """The Connection step's shutdown closes and clears the connection."""
        cons = self.NoopConsumer()
        cons.blueprint.state = RUN
        connection_step = find_step(cons, consumer.Connection)
        # Keep a handle: the attribute is expected to be None afterwards.
        original = cons.connection
        connection_step.shutdown(cons)
        original.close.assert_called()
        assert cons.connection is None

    def test_close_connection__heart_shutdown(self):
        """Events/Heart step shutdown closes the dispatcher and stops heart."""
        cons = self.NoopConsumer()
        # Grab the mocks first so they can be inspected after the shutdowns.
        dispatcher = cons.event_dispatcher
        heartbeat = cons.heart
        cons.event_dispatcher.enabled = True
        cons.blueprint.state = RUN
        events_step = find_step(cons, consumer.Events)
        events_step.shutdown(cons)
        heart_step = find_step(cons, consumer.Heart)
        heart_step.shutdown(cons)
        dispatcher.close.assert_called()
        heartbeat.stop.assert_called_with()

    @patch('celery.worker.consumer.consumer.warn')
    def test_receive_message_unknown(self, warn):
        """A message built without task fields makes the consumer warn."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        chan = Mock(name='.channeol')
        message = create_message(chan, unknown={'baz': '!!!'})

        on_message = self._get_on_message(cons)
        on_message(message)
        warn.assert_called()

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is acknowledged when to_timestamp overflows."""
        to_timestamp.side_effect = OverflowError()
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        # NOTE(review): ('2, 2') is a plain string, not a tuple -- confirm
        # this is the intended argument payload.
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=('2, 2'), kwargs={},
            eta=datetime.now().isoformat(),
        )
        cons.update_strategies()
        on_message = self._get_on_message(cons)
        on_message(message)
        assert message.acknowledged

    @patch('celery.worker.consumer.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        """A strategy raising InvalidTaskError is reported through error()."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1)
        cons.update_strategies()
        # Replace the real strategy so handling the message always fails.
        failing = Mock(name='strategy')
        cons.strategies[self.foo_task.name] = failing
        failing.side_effect = InvalidTaskError()

        on_message = self._get_on_message(cons)
        on_message(message)
        error.assert_called()
        logged = error.call_args[0][0]
        assert 'Received invalid task message' in logged

    @patch('celery.worker.consumer.consumer.crit')
    def test_on_decode_error(self, crit):
        """on_decode_error() acks the message and logs via crit()."""
        cons = self.LoopConsumer()

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        undecodable = MockMessage()
        cons.on_decode_error(undecodable, KeyError('foo'))
        assert undecodable.ack.call_count
        logged = crit.call_args[0][0]
        assert "Can't decode message body" in logged

    def _get_on_message(self, c):
        """Run the consumer loop once and return the installed on_message.

        The consumer's collaborators are replaced by mocks and the mocked
        connection's ``drain_events`` raises ``WorkerShutdown`` so that the
        loop exits.
        """
        if c.qos is None:
            c.qos = Mock()
        c.task_consumer = Mock()
        c.event_dispatcher = mock_event_dispatcher()
        c.connection = Mock(name='.connection')
        c.connection.get_heartbeat_interval.return_value = 0
        c.connection.drain_events.side_effect = WorkerShutdown()

        with pytest.raises(WorkerShutdown):
            c.loop(*c.loop_args())
        handler = c.task_consumer.on_message
        assert handler
        return c.task_consumer.on_message

    def test_receieve_message(self):
        """A task message without ETA is put on the buffer as a Request."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        cons.update_strategies()
        on_message = self._get_on_message(cons)
        on_message(message)

        request = self.buffer.get_nowait()
        assert isinstance(request, Request)
        assert request.name == self.foo_task.name
        assert request.execute() == 2 * 4 * 8
        # Nothing should have been scheduled on the timer.
        assert self.timer.empty()

    def test_start_channel_error(self):
        """A KeyError registered as a channel error propagates from start()."""
        cons = self.NoopConsumer(task_events=False, pool=BasePool())
        cons.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        cons.channel_errors = (KeyError,)
        try:
            with pytest.raises(KeyError):
                cons.start()
        finally:
            # Always stop the timer so it does not outlive the test.
            if cons.timer:
                cons.timer.stop()

    def test_start_connection_error(self):
        """With KeyError as a connection error, start() raises SyntaxError."""
        cons = self.NoopConsumer(task_events=False, pool=BasePool())
        cons.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        cons.connection_errors = (KeyError,)
        try:
            with pytest.raises(SyntaxError):
                cons.start()
        finally:
            # Always stop the timer so it does not outlive the test.
            if cons.timer:
                cons.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        """loop() returns normally when drain_events raises socket.timeout."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Drop the consumer's connection, then time out.
                self.obj.connection = None
                raise socket.timeout(10)

        cons = self.NoopConsumer()
        cons.connection = Connection(self.app.conf.broker_url)
        cons.connection.obj = cons
        cons.qos = QoS(cons.task_consumer.qos, 10)
        loop_args = cons.loop_args()
        cons.loop(*loop_args)

    def test_loop_when_socket_error(self):
        """socket.error from drain_events propagates in RUN, not in CLOSE."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        failing = Connection(self.app.conf.broker_url)
        cons.connection = failing
        cons.connection.obj = cons
        cons.qos = QoS(cons.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            cons.loop(*cons.loop_args())

        # drain_events cleared the connection; put it back before the
        # second pass with the blueprint closed.
        cons.blueprint.state = CLOSE
        cons.connection = failing
        cons.loop(*cons.loop_args())

    def test_loop(self):
        """loop() starts consuming and applies prefetch-count changes."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

            @property
            def supports_heartbeats(self):
                return False

        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.connection = Connection(self.app.conf.broker_url)
        cons.connection.obj = cons
        cons.connection.get_heartbeat_interval = Mock(return_value=None)
        cons.qos = QoS(cons.task_consumer.qos, 10)

        # Run the loop twice in a row.
        for _ in range(2):
            cons.loop(*cons.loop_args())
        assert cons.task_consumer.consume.call_count
        cons.task_consumer.qos.assert_called_with(prefetch_count=10)
        assert cons.qos.value == 10

        # A deferred decrement changes the value at once; update() then
        # pushes it to the task consumer.
        cons.qos.decrement_eventually()
        assert cons.qos.value == 9
        cons.qos.update()
        assert cons.qos.value == 9
        cons.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        """ignore_errors() swallows connection/channel errors only."""
        cons = self.NoopConsumer()
        cons.connection_errors = (AttributeError, KeyError,)
        cons.channel_errors = (SyntaxError,)

        # Each of these is listed above and must be suppressed.
        ignore_errors(cons, Mock(side_effect=AttributeError('foo')))
        ignore_errors(cons, Mock(side_effect=KeyError('foo')))
        ignore_errors(cons, Mock(side_effect=SyntaxError('foo')))

        # IndexError is in neither tuple, so it must escape.
        unlisted = Mock(side_effect=IndexError('foo'))
        with pytest.raises(IndexError):
            ignore_errors(cons, unlisted)

    def test_apply_eta_task(self):
        """apply_eta_task() reserves and buffers the task, lowering QoS by 1."""
        cons = self.NoopConsumer()
        cons.qos = QoS(None, 10)
        request = Mock(name='task', id='1234213')
        before = cons.qos.value
        cons.apply_eta_task(request)
        assert request in state.reserved_requests
        assert cons.qos.value == before - 1
        assert self.buffer.get_nowait() is request

    def test_receieve_message_eta_isoformat(self):
        """An ISO-format ETA schedules the task and raises the QoS value."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        tomorrow = datetime.now() + timedelta(days=1)
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            eta=tomorrow.isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        cons.qos = QoS(cons.task_consumer.qos, 1)
        initial_pcount = cons.qos.value
        cons.event_dispatcher.enabled = False
        cons.update_strategies()
        on_message = self._get_on_message(cons)
        on_message(message)
        cons.timer.stop()
        cons.timer.join(1)

        # The third element of every timer queue entry is the scheduled
        # item; its first argument is the request.
        scheduled = [entry[2] for entry in self.timer.queue]
        scheduled_names = [item.args[0].name for item in scheduled]
        assert self.foo_task.name in scheduled_names
        assert cons.qos.value > initial_pcount
        cons.timer.stop()

    def test_pidbox_callback(self):
        """The pidbox on_message forwards to the node and survives errors.

        A ValueError from the node must additionally trigger ``reset``.
        """
        cons = self.NoopConsumer()
        box = find_step(cons, consumer.Control).box
        box.node = Mock()
        box.reset = Mock()

        # Plain delivery.
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        # The handler raises KeyError.
        box.node = Mock()
        box.node.handle_message.side_effect = KeyError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        # The handler raises ValueError: the box is expected to reset.
        box.node = Mock()
        box.node.handle_message.side_effect = ValueError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')
        box.reset.assert_called()

    def test_revoke(self):
        """A message whose id is in state.revoked never reaches the buffer."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        chan = Mock(name='channel')
        task_id = uuid()
        message = self.create_task_message(
            chan, self.foo_task.name,
            args=[2, 4, 8], kwargs={}, id=task_id,
        )

        # Mark the task as revoked before it is delivered.
        state.revoked.add(task_id)

        on_message = self._get_on_message(cons)
        on_message(message)
        assert self.buffer.empty()

    def test_receieve_message_not_registered(self):
        """A message naming an unknown task is neither queued nor scheduled."""
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        cons.steps.pop()
        chan = Mock(name='channel')
        message = self.create_task_message(
            chan, 'x.X.31x', args=[2, 4, 8], kwargs={},
        )

        on_message = self._get_on_message(cons)
        assert not on_message(message)
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()

    @patch('celery.worker.consumer.consumer.warn')
    @patch('celery.worker.consumer.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        """A message with ``headers = None`` is rejected via reject_log_error.

        The consumer must warn, and nothing may reach the buffer or timer.
        """
        cons = self.LoopConsumer()
        cons.blueprint.state = RUN
        chan = Mock(name='channel')
        message = self.create_task_message(
            chan, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        message.headers = None

        cons.update_strategies()
        cons.connection_errors = (socket.error,)
        message.reject = Mock()
        message.reject.side_effect = socket.error('foo')
        on_message = self._get_on_message(cons)
        assert not on_message(message)
        warn.assert_called()
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()
        message.reject_log_error.assert_called_with(
            logger, cons.connection_errors)

    def test_receive_message_eta(self):
        """A task message with an ETA a day ahead is held in the timer queue.

        Set ``C_DEBUG_TEST`` in the environment to trace every phase of the
        test on ``sys.__stderr__``.
        """
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        c = self.LoopConsumer()
        pp('-CREATE MYKOMBUCONSUMER')
        c.steps.pop()
        channel = Mock(name='channel')
        pp('+ CREATE MESSAGE')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 1')
            p = c.app.conf.broker_connection_retry
            c.app.conf.broker_connection_retry = False
            try:
                pp('+ BLUEPRINT START 2')
                c.blueprint.start(c)
                pp('- BLUEPRINT START 2')
            finally:
                # Restore the retry setting even if the second start fails,
                # so the temporary override cannot leak out of this test.
                c.app.conf.broker_connection_retry = p
            pp('+ BLUEPRINT RESTART')
            c.blueprint.restart(c)
            pp('- BLUEPRINT RESTART')
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(c)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            c.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                c.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                # Best effort: a RuntimeError from join() is ignored.
                pass

        in_hold = c.timer.queue[0]
        assert len(in_hold) == 3
        _eta, _priority, entry = in_hold
        task = entry.args[0]
        assert isinstance(task, Request)
        assert task.name == self.foo_task.name
        assert task.execute() == 2 * 4 * 8
        with pytest.raises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        """The pidbox reset() calls channel.close() even when it raises."""
        cons = self.NoopConsumer()
        box = find_step(cons, consumer.Control).box
        box.node = Mock()
        channel = Mock()
        box.node.channel = channel
        channel.close.side_effect = socket.error('foo')
        cons.connection_errors = (socket.error,)
        box.reset()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool the Control step spawns the gPidbox loop."""
        green_pool = Mock(is_green=True)
        cons = self.NoopConsumer(pool=green_pool)
        control = find_step(cons, consumer.Control)
        assert isinstance(control.box, gPidbox)
        control.start(cons)
        cons.pool.spawn_n.assert_called_with(control.box.loop, cons)

    def test_green_pidbox_node(self):
        """The green pidbox loop consumes on its own connection and closes it.

        The stand-in connection raises ``socket.timeout`` on the first drain;
        on the next one it clears the consumer's connection and sets the
        box's ``_node_shutdown`` event.
        """
        pool = Mock()
        pool.is_green = True
        # Use the pool built above rather than a second throwaway mock.
        c = self.NoopConsumer(pool=pool)
        controller = find_step(c, consumer.Control)

        class BConsumer(Mock):
            # Mock usable as a context manager: consume on enter,
            # cancel on exit.

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        c.connection_for_read = lambda: Connection(obj=c)
        controller.box.loop(c)

        controller.box.node.listen.assert_called()
        assert controller.box.consumer
        controller.box.consumer.consume.assert_called_with()

        assert c.connection is None
        assert connections[0].closed

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.functional.sleep')
    def test_connect_errback(self, sleep, connect):
        """connect() with ChannelError registered as a connection error.

        ``_establish_connection`` and ``sleep`` are patched out; the patched
        connect is configured with a ``ChannelError`` for its first call.
        """
        c = self.NoopConsumer()
        connect.on_nth_call_do(ChannelError('error'), n=1)
        # Scope the override: assigning on the Transport class directly
        # would leak the altered error tuple into every later test.
        with patch.object(Transport, 'connection_errors', (ChannelError,)):
            c.connect()
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        """Control.stop() runs when the stopped event is already set."""
        cons = self.NoopConsumer()
        control = find_step(cons, consumer.Control)
        stopped = Event()
        control._node_stopped = stopped
        control._node_shutdown = Event()
        stopped.set()
        control.stop(cons)

    def test_start__loop(self):
        """start() keeps calling loop until it raises, then propagates.

        Covers a KeyError raised on the second loop pass and a
        ``socket.error`` raised by a mocked loop.
        """

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        on_init = Mock(name='init_callback')
        cons = self.NoopConsumer(init_callback=on_init)
        cons.qos = _QoS()
        cons.connection = Connection(self.app.conf.broker_url)
        cons.connection.get_heartbeat_interval = Mock(return_value=None)
        cons.iterations = 0

        def raises_KeyError(*args, **kwargs):
            # Sync the fake QoS on every pass; fail from the second pass on.
            cons.iterations += 1
            if cons.qos.prev != cons.qos.value:
                cons.qos.update()
            if cons.iterations >= 2:
                raise KeyError('foo')

        cons.loop = raises_KeyError
        with pytest.raises(KeyError):
            cons.start()
        assert cons.iterations == 2
        assert cons.qos.prev == cons.qos.value

        # Second scenario: task events off, the loop fails immediately.
        on_init.reset_mock()
        cons = self.NoopConsumer(task_events=False, init_callback=on_init)
        cons.qos = _QoS()
        cons.connection = Connection(self.app.conf.broker_url)
        cons.connection.get_heartbeat_interval = Mock(return_value=None)
        cons.loop = Mock(side_effect=socket.error('foo'))
        with pytest.raises(socket.error):
            cons.start()
        cons.loop.assert_called()

    def test_reset_connection_with_no_node(self):
        """The blueprint starts after the last step has been popped."""
        cons = self.NoopConsumer()
        cons.steps.pop()
        blueprint = cons.blueprint
        blueprint.start(cons)
Esempio n. 6
0
class test_Consumer(Case):
    def setUp(self):
        """Create the ready queue, ETA timer and a fully verbose logger."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        default_logger = current_app.log.get_default_logger()
        self.logger = default_logger
        # Level 0: let every record through.
        self.logger.setLevel(0)

    def tearDown(self):
        """Stop the ETA timer created in setUp()."""
        schedule = self.eta_schedule
        schedule.stop()

    def test_info(self):
        """info reports the prefetch count, and broker once connected."""
        consumer = MyKombuConsumer(self.ready_queue,
                                   self.eta_schedule,
                                   self.logger,
                                   send_events=False)
        consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)
        before = consumer.info
        self.assertEqual(before["prefetch_count"], 10)
        self.assertFalse(before["broker"])

        # Once a connection is attached, broker details become truthy.
        consumer.connection = current_app.broker_connection()
        after = consumer.info
        self.assertTrue(after["broker"])

    def test_start_when_closed(self):
        """start() can be called while the consumer state is CLOSE."""
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        consumer._state = CLOSE
        consumer.start()

    def test_connection(self):
        """Connection lifecycle across reset, stop_consumers and close."""
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)

        # Stopping consumers without closing keeps the connection.
        consumer._state = RUN
        consumer.event_dispatcher = None
        consumer.stop_consumers(close_connection=False)
        self.assertTrue(consumer.connection)

        # The default stop_consumers() also drops the connection.
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)
        consumer.stop_consumers()

        consumer.stop()
        consumer.close_connection()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

    def test_close_connection(self):
        """stop_consumers() closes the event dispatcher and the heart."""
        # A bare consumer can close its connection while running.
        bare = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        bare._state = RUN
        bare.close_connection()

        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        dispatcher = consumer.event_dispatcher = Mock()
        dispatcher.enabled = True
        heartbeat = consumer.heart = MockHeart()
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertTrue(dispatcher.close.call_count)
        self.assertTrue(heartbeat.closed)

    def test_receive_message_unknown(self):
        """A message without task fields emits an 'unknown message' warning."""
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        channel = Mock()
        message = create_message(channel, unknown={"baz": "!!!"})
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
            consumer.receive_message(message.decode(), message)

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is acknowledged when to_timestamp overflows."""
        to_timestamp.side_effect = OverflowError()
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        # NOTE(review): ("2, 2") is a plain string, not a tuple -- confirm
        # this is the intended argument payload.
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 args=("2, 2"),
                                 kwargs={},
                                 eta=datetime.now().isoformat())
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()
        consumer.update_strategies()

        consumer.receive_message(message.decode(), message)
        self.assertTrue(message.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    def test_receive_message_InvalidTaskError(self):
        """A message with string kwargs is logged as an invalid task."""
        mock_logger = Mock()
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, mock_logger,
            send_events=False)
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 args=(1, 2),
                                 kwargs="foobarbaz",
                                 id=1)
        consumer.update_strategies()
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        consumer.receive_message(message.decode(), message)
        logged = mock_logger.error.call_args[0][0]
        self.assertIn("Received invalid task message", logged)

    def test_on_decode_error(self):
        """on_decode_error() acks the message and logs at critical level."""
        mock_logger = Mock()
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, mock_logger,
            send_events=False)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        undecodable = MockMessage()
        consumer.on_decode_error(undecodable, KeyError("foo"))
        self.assertTrue(undecodable.ack.call_count)
        logged = mock_logger.critical.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def test_receieve_message(self):
        """A task message without ETA goes to the ready queue as a Request."""
        consumer = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False)
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 args=[2, 4, 8],
                                 kwargs={})
        consumer.update_strategies()

        consumer.event_dispatcher = Mock()
        consumer.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing should have been put on the ETA schedule.
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """With KeyError as a connection error, start() raises SyntaxError.

        ``consume_messages`` raises KeyError on its first call and
        SyntaxError on every later call.
        """
        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue,
                         self.eta_schedule,
                         self.logger,
                         send_events=False,
                         pool=BasePool())
        l.connection_errors = (KeyError, )
        try:
            with self.assertRaises(SyntaxError):
                l.start()
        finally:
            # Stop the heart and timer even if the assertion fails, so
            # they do not outlive the test.
            l.heart.stop()
            l.priority_timer.stop()

    def test_start_channel_error(self):
        """With KeyError as a channel error, start() raises SyntaxError.

        Regression test for AMQPChannelExceptions that can occur within the
        consumer (i.e. 404 errors).  ``consume_messages`` raises KeyError on
        its first call and SyntaxError on every later call.
        """

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue,
                         self.eta_schedule,
                         self.logger,
                         send_events=False,
                         pool=BasePool())

        l.channel_errors = (KeyError, )
        try:
            with self.assertRaises(SyntaxError):
                l.start()
        finally:
            # Stop the heart and timer even if the assertion fails, so
            # they do not outlive the test.
            l.heart.stop()
            l.priority_timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        """consume_messages() must not propagate socket.timeout."""
        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Detach the connection from the consumer before timing out.
                self.obj.connection = None
                raise socket.timeout(10)

        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        conn = worker.connection = Connection()
        worker.task_consumer = Mock()
        conn.obj = worker
        worker.qos = QoS(worker.task_consumer, 10, worker.logger)
        worker.consume_messages()

    def test_consume_messages_when_socket_error(self):
        """socket.error escapes while running, but not once state is CLOSE."""
        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error("foo")

        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        worker._state = RUN
        conn = worker.connection = Connection()
        conn.obj = worker
        worker.task_consumer = Mock()
        worker.qos = QoS(worker.task_consumer, 10, worker.logger)
        self.assertRaises(socket.error, worker.consume_messages)

        # drain_events cleared the connection above; put it back and retry
        # in the CLOSE state, where no exception is expected.
        worker._state = CLOSE
        worker.connection = conn
        worker.consume_messages()

    def test_consume_messages(self):
        """consume_messages() starts consuming and applies the QoS value."""
        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        conn = worker.connection = Connection()
        conn.obj = worker
        task_consumer = worker.task_consumer = Mock()
        worker.qos = QoS(task_consumer, 10, worker.logger)

        for _ in range(2):
            worker.consume_messages()
        self.assertTrue(task_consumer.consume.call_count)
        task_consumer.qos.assert_called_with(prefetch_count=10)

        # After a decrement the next pass must push the lowered prefetch.
        worker.qos.decrement()
        worker.consume_messages()
        task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        """maybe_conn_error() swallows known errors and re-raises the rest."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        worker.connection_errors = (KeyError, )
        worker.channel_errors = (SyntaxError, )
        # None of these is expected to escape maybe_conn_error().
        for swallowed in (AttributeError, KeyError, SyntaxError):
            worker.maybe_conn_error(Mock(side_effect=swallowed("foo")))
        # Anything else must propagate to the caller.
        with self.assertRaises(IndexError):
            worker.maybe_conn_error(Mock(side_effect=IndexError("foo")))

    def test_apply_eta_task(self):
        """apply_eta_task() reserves the task, lowers QoS by one, queues it."""
        from celery.worker import state
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        worker.qos = QoS(None, 10, worker.logger)

        request = object()
        prefetch_before = worker.qos.value
        worker.apply_eta_task(request)
        self.assertIn(request, state.reserved_requests)
        self.assertEqual(worker.qos.value, prefetch_before - 1)
        self.assertIs(self.ready_queue.get_nowait(), request)

    def test_receieve_message_eta_isoformat(self):
        """A message whose ``eta`` is an ISO-8601 string lands on the timer.

        After receive_message() the ETA schedule must hold an entry whose
        first argument is a request for ``foo_task``, and the task consumer's
        qos() must have been called.
        """
        worker = MyKombuConsumer(self.ready_queue,
                                 self.eta_schedule,
                                 self.logger,
                                 send_events=False)
        m = create_message(Mock(),
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        worker.task_consumer = Mock()
        worker.qos = QoS(worker.task_consumer,
                         worker.initial_prefetch_count,
                         worker.logger)
        worker.event_dispatcher = Mock()
        worker.enabled = False
        worker.update_strategies()
        worker.receive_message(m.decode(), m)
        # Stop the timer before inspecting its queue.
        worker.eta_schedule.stop()

        scheduled = [entry[2] for entry in self.eta_schedule.queue]
        self.assertTrue(
            any(item.args[0].name == foo_task.name for item in scheduled))
        self.assertTrue(worker.task_consumer.qos.call_count)
        worker.eta_schedule.stop()

    def test_on_control(self):
        """on_control() hands the message to the pidbox node.

        KeyError and ValueError raised by the node must not escape; after
        the ValueError case the pidbox node must have been reset.
        """
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        node = worker.pidbox_node = Mock()
        worker.reset_pidbox_node = Mock()

        worker.on_control("foo", "bar")
        node.handle_message.assert_called_with("foo", "bar")

        node = worker.pidbox_node = Mock()
        node.handle_message.side_effect = KeyError("foo")
        worker.on_control("foo", "bar")
        node.handle_message.assert_called_with("foo", "bar")

        node = worker.pidbox_node = Mock()
        node.handle_message.side_effect = ValueError("foo")
        worker.on_control("foo", "bar")
        node.handle_message.assert_called_with("foo", "bar")
        worker.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        """A message whose id is in the revoked set never reaches the queue."""
        from celery.worker.state import revoked
        queue = FastQueue()
        worker = MyKombuConsumer(
            queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        task_id = uuid()
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 args=[2, 4, 8],
                                 kwargs={},
                                 id=task_id)
        revoked.add(task_id)

        worker.receive_message(message.decode(), message)
        self.assertTrue(queue.empty())

    def test_receieve_message_not_registered(self):
        """A message naming an unregistered task is neither queued nor held."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(
            Mock(), task="x.X.31x", args=[2, 4, 8], kwargs={})

        worker.event_dispatcher = Mock()
        self.assertFalse(worker.receive_message(message.decode(), message))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_ack_raises(self):
        """A connection error raised by ack() is logged, not propagated."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(Mock(), args=[2, 4, 8], kwargs={})

        worker.event_dispatcher = Mock()
        worker.connection_errors = (socket.error, )
        worker.logger = Mock()
        message.ack = Mock(side_effect=socket.error("foo"))
        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
            self.assertFalse(
                worker.receive_message(message.decode(), message))
        # Nothing may have been queued or scheduled for the bad message.
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())
        message.ack.assert_called_with()
        self.assertTrue(worker.logger.critical.call_count)

    def test_receieve_message_eta(self):
        """A task with an ETA one day ahead is held on the ETA schedule."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        worker.event_dispatcher = Mock()
        worker.event_dispatcher._outbound_buffer = deque()
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 args=[2, 4, 8],
                                 kwargs={},
                                 eta=tomorrow.isoformat())

        worker.reset_connection()
        # Reset once more with connection retries disabled, restoring the
        # setting afterwards whatever happens.
        retry_setting = worker.app.conf.BROKER_CONNECTION_RETRY
        worker.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            worker.reset_connection()
        finally:
            worker.app.conf.BROKER_CONNECTION_RETRY = retry_setting
        worker.stop_consumers()
        worker.event_dispatcher = Mock()
        worker.receive_message(message.decode(), message)
        worker.eta_schedule.stop()

        held = self.eta_schedule.queue[0]
        self.assertEqual(len(held), 3)
        _eta, _priority, entry = held
        request = entry.args[0]
        self.assertIsInstance(request, Request)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # The task is on the schedule only; the ready queue stays empty.
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_reset_pidbox_node(self):
        """reset_pidbox_node() tolerates a connection error on channel close."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        node = worker.pidbox_node = Mock()
        channel = node.channel = Mock()
        worker.connection = Mock()
        channel.close.side_effect = socket.error("foo")
        worker.connection_errors = (socket.error, )
        worker.reset_pidbox_node()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool the pidbox node loop is spawned on the pool."""
        worker = MyKombuConsumer(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        green_pool = worker.pool = Mock()
        green_pool.is_green = True
        worker.reset_pidbox_node()
        green_pool.spawn_n.assert_called_with(worker._green_pidbox_node)

    def test__green_pidbox_node(self):
        """Run _green_pidbox_node() once against a fake connection.

        Checks that the node listens with ``on_control`` as callback, that
        the broadcast consumer is consumed from, and that afterwards the
        consumer's connection is None and the fake connection was closed.
        """
        l = MyKombuConsumer(self.ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)
        l.pidbox_node = Mock()

        # Mock usable as a context manager: consume on enter, cancel on exit.
        class BConsumer(Mock):
            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        l.pidbox_node.listen = BConsumer()
        connections = []

        # Minimal stand-in connection; every instance records itself in
        # ``connections`` so the test can inspect it afterwards.
        class Connection(object):
            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def drain_events(self, **kwargs):
                # Clear the consumer's connection and set the shutdown event
                # on the first drain.
                self.obj.connection = None
                self.obj._pidbox_node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l._open_connection = lambda: Connection(obj=l)
        l._green_pidbox_node()

        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
        self.assertTrue(l.broadcast_consumer)
        l.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    def test_start__consume_messages(self):
        """Exercise start() with stubbed consume_messages implementations.

        First scenario: reset_connection raises KeyError once one iteration
        has run, and that KeyError must escape start() after exactly one
        consume_messages call.  Second scenario: consume_messages raises
        socket.error, which must escape start().  In both cases the init
        callback must have been called.
        """
        # Bare QoS stand-in: update() copies value into prev.
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Fail once at least one consume iteration has been counted.
                if self.iterations >= 1:
                    raise KeyError("foo")

        init_callback = Mock()
        l = _Consumer(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            # Counts calls, syncs the fake QoS, and raises from the second
            # call onwards.
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Second scenario: a fresh consumer whose consume_messages raises
        # socket.error.
        init_callback.reset_mock()
        l = _Consumer(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = BrokerConnection()
        l.consume_messages = Mock(side_effect=socket.error("foo"))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertTrue(l.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """reset_connection() must work on a consumer that has no pool."""
        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
        # Identity check, matching the assertIsNone calls used elsewhere in
        # this file.
        self.assertIsNone(l.pool)
        l.reset_connection()
Esempio n. 7
0
class test_Consumer(AppCase):

    def setup(self):
        """Create a fresh buffer and timer, and register ``foo_task``."""
        self.buffer = FastQueue()
        self.timer = Timer()

        # Task registered on this test's app; multiplies its three arguments.
        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        """Stop the timer created in setup()."""
        self.timer.stop()

    def test_info(self):
        """WorkController.stats() reports prefetch count and broker info."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.task_consumer = Mock()
        cons.qos = QoS(cons.task_consumer.qos, 10)
        cons.connection = Mock()
        cons.connection.info.return_value = {'foo': 'bar'}

        # Wire a controller around the consumer so stats() can be queried.
        cons.controller = cons.app.WorkController()
        cons.controller.pool = Mock()
        cons.controller.pool.info.return_value = [Mock(), Mock()]
        cons.controller.consumer = cons

        stats = cons.controller.stats()
        self.assertEqual(stats['prefetch_count'], 10)
        self.assertTrue(stats['broker'])

    def test_start_when_closed(self):
        """start() on a consumer whose blueprint is CLOSE must not raise."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = CLOSE
        cons.start()

    def test_connection(self):
        """Connection is set by blueprint start and cleared by shutdown."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )

        def assert_disconnected():
            self.assertIsNone(cons.connection)
            self.assertIsNone(cons.task_consumer)

        cons.blueprint.start(cons)
        self.assertIsInstance(cons.connection, Connection)

        # A restart must leave a usable connection in place.
        cons.blueprint.state = RUN
        cons.event_dispatcher = None
        cons.blueprint.restart(cons)
        self.assertTrue(cons.connection)

        cons.blueprint.state = RUN
        cons.shutdown()
        assert_disconnected()

        cons.blueprint.start(cons)
        self.assertIsInstance(cons.connection, Connection)
        cons.blueprint.restart(cons)

        cons.stop()
        cons.shutdown()
        assert_disconnected()

    def test_close_connection(self):
        """Shutting down the Connection, Events and Heart steps cleans up."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        connection_step = find_step(cons, consumer.Connection)
        conn = cons.connection = Mock()
        connection_step.shutdown(cons)
        self.assertTrue(conn.close.called)
        self.assertIsNone(cons.connection)

        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        eventer = cons.event_dispatcher = Mock()
        eventer.enabled = True
        heart = cons.heart = MockHeart()
        cons.blueprint.state = RUN
        events_step = find_step(cons, consumer.Events)
        events_step.shutdown(cons)
        heart_step = find_step(cons, consumer.Heart)
        heart_step.shutdown(cons)
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        """A message carrying only unknown fields triggers a warning."""
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        message = create_message(Mock(), unknown={'baz': '!!!'})
        cons.event_dispatcher = Mock()
        cons.node = MockNode()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is acknowledged when the ETA conversion overflows."""
        to_timestamp.side_effect = OverflowError()
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        # NOTE(review): args is the string '2, 2', not a tuple; confirm
        # this is intentional.
        message = create_message(
            Mock(), task=self.foo_task.name,
            args=('2, 2'),
            kwargs={},
            eta=datetime.now().isoformat(),
        )
        cons.event_dispatcher = Mock()
        cons.node = MockNode()
        cons.update_strategies()
        cons.qos = Mock()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        self.assertTrue(message.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        """A task message with non-dict kwargs is reported as invalid."""
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.event_dispatcher = Mock()
        cons.steps.pop()
        message = create_message(
            Mock(), task=self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1,
        )
        cons.update_strategies()
        cons.event_dispatcher = Mock()

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        logged = error.call_args[0][0]
        self.assertIn('Received invalid task message', logged)

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        """on_decode_error() acks the message and logs a critical error."""
        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        undecodable = MockMessage()
        cons.on_decode_error(undecodable, KeyError('foo'))
        self.assertTrue(undecodable.ack.call_count)
        logged = crit.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def _get_on_message(self, l):
        """Return the callback the loop registers on the task consumer.

        Replaces the consumer's collaborators with mocks, runs one pass of
        ``l.loop`` (aborted by a SystemExit raised from drain_events) and
        returns the first positional argument of the last
        ``task_consumer.register_callback`` call.
        """
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = Mock()
        l.task_consumer = Mock()
        l.connection = Mock()
        # Break out of the loop as soon as it tries to drain events.
        l.connection.drain_events.side_effect = SystemExit()

        with self.assertRaises(SystemExit):
            l.loop(*l.loop_args())

        register = l.task_consumer.register_callback
        self.assertTrue(register.called)
        positional = register.call_args[0]
        return positional[0]

    def test_receieve_message(self):
        """A plain task message ends up in the buffer as a Request."""
        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        cons.event_dispatcher = Mock()
        message = create_message(
            Mock(), task=self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        cons.update_strategies()
        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)

        request = self.buffer.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.name, self.foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing was deferred to the timer.
        self.assertTrue(self.timer.empty())

    def test_start_channel_error(self):
        """A channel error raised by loop() propagates out of start()."""

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if self.iterations:
                    raise SyntaxError('bar')
                self.iterations = 1
                raise KeyError('foo')

        cons = MockConsumer(
            self.buffer.put, timer=self.timer,
            send_events=False, pool=BasePool(), app=self.app,
        )
        cons.channel_errors = (KeyError, )
        self.assertRaises(KeyError, cons.start)
        cons.timer.stop()

    def test_start_connection_error(self):
        """start() survives a connection error and surfaces the next error.

        loop() raises KeyError (registered below as a connection error) on
        its first call and SyntaxError on the second; the SyntaxError is the
        one expected to escape start().
        """

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if self.iterations:
                    raise SyntaxError('bar')
                self.iterations = 1
                raise KeyError('foo')

        cons = MockConsumer(
            self.buffer.put, timer=self.timer,
            send_events=False, pool=BasePool(), app=self.app,
        )

        cons.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            cons.start()
        cons.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        """loop() must not propagate socket.timeout from drain_events."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Detach the connection from the consumer before timing out.
                self.obj.connection = None
                raise socket.timeout(10)

        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        conn = cons.connection = Connection()
        cons.task_consumer = Mock()
        conn.obj = cons
        cons.qos = QoS(cons.task_consumer.qos, 10)
        cons.loop(*cons.loop_args())

    def test_loop_when_socket_error(self):
        """socket.error escapes loop() while RUN, but not once CLOSE."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        conn = cons.connection = Connection()
        conn.obj = cons
        cons.task_consumer = Mock()
        cons.qos = QoS(cons.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            cons.loop(*cons.loop_args())

        # drain_events cleared the connection; restore it and run again in
        # the CLOSE state, where no exception is expected.
        cons.blueprint.state = CLOSE
        cons.connection = conn
        cons.loop(*cons.loop_args())

    def test_loop(self):
        """loop() starts consuming; QoS changes reach the task consumer."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        conn = cons.connection = Connection()
        conn.obj = cons
        task_consumer = cons.task_consumer = Mock()
        cons.qos = QoS(task_consumer.qos, 10)

        for _ in range(2):
            cons.loop(*cons.loop_args())
        self.assertTrue(task_consumer.consume.call_count)
        task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(cons.qos.value, 10)

        # The decrement is visible on the QoS value at once; the task
        # consumer is checked after update().
        cons.qos.decrement_eventually()
        self.assertEqual(cons.qos.value, 9)
        cons.qos.update()
        self.assertEqual(cons.qos.value, 9)
        task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        """ignore_errors() swallows connection/channel errors only."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.connection_errors = (AttributeError, KeyError, )
        cons.channel_errors = (SyntaxError, )
        # Each registered error type must be swallowed.
        for swallowed in (AttributeError, KeyError, SyntaxError):
            ignore_errors(cons, Mock(side_effect=swallowed('foo')))
        # An unregistered error type must propagate.
        with self.assertRaises(IndexError):
            ignore_errors(cons, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        """apply_eta_task() reserves the task, lowers QoS by one, buffers it."""
        from celery.worker import state
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.qos = QoS(None, 10)

        request = object()
        prefetch_before = cons.qos.value
        cons.apply_eta_task(request)
        self.assertIn(request, state.reserved_requests)
        self.assertEqual(cons.qos.value, prefetch_before - 1)
        self.assertIs(self.buffer.get_nowait(), request)

    def test_receieve_message_eta_isoformat(self):
        """A message whose ``eta`` is an ISO-8601 string lands on the timer.

        After the message callback runs, the timer queue must hold an entry
        whose first argument is a request for ``foo_task``, and the QoS value
        must be greater than it was before the message arrived.
        """
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(
            Mock(), task=self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 1)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        # Stop the timer before inspecting its queue.
        l.timer.stop()
        l.timer.join(1)

        scheduled = [entry[2] for entry in self.timer.queue]
        self.assertTrue(
            any(item.args[0].name == self.foo_task.name
                for item in scheduled))
        self.assertGreater(l.qos.value, current_pcount)
        l.timer.stop()

    def test_pidbox_callback(self):
        """The pidbox on_message handler forwards to the node.

        KeyError and ValueError raised by the node must not escape; after
        the ValueError case the box must have been reset.
        """
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        box = find_step(cons, consumer.Control).box
        node = box.node = Mock()
        box.reset = Mock()

        box.on_message('foo', 'bar')
        node.handle_message.assert_called_with('foo', 'bar')

        node = box.node = Mock()
        node.handle_message.side_effect = KeyError('foo')
        box.on_message('foo', 'bar')
        node.handle_message.assert_called_with('foo', 'bar')

        node = box.node = Mock()
        node.handle_message.side_effect = ValueError('foo')
        box.on_message('foo', 'bar')
        node.handle_message.assert_called_with('foo', 'bar')
        self.assertTrue(box.reset.called)

    def test_revoke(self):
        """A message whose id is in the revoked set never reaches the buffer."""
        from celery.worker.state import revoked
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        task_id = uuid()
        message = create_message(
            Mock(), task=self.foo_task.name, args=[2, 4, 8],
            kwargs={}, id=task_id,
        )
        revoked.add(task_id)

        on_message = self._get_on_message(cons)
        on_message(message.decode(), message)
        self.assertTrue(self.buffer.empty())

    def test_receieve_message_not_registered(self):
        """A message naming an unregistered task is neither buffered nor held."""
        cons = _MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        cons.blueprint.state = RUN
        cons.steps.pop()
        message = create_message(
            Mock(), task='x.X.31x', args=[2, 4, 8], kwargs={})

        cons.event_dispatcher = Mock()
        on_message = self._get_on_message(cons)
        self.assertFalse(on_message(message.decode(), message))
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        """A connection error raised by reject() is logged, not propagated."""
        cons = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        cons.blueprint.state = RUN
        message = create_message(Mock(), args=[2, 4, 8], kwargs={})

        cons.event_dispatcher = Mock()
        cons.connection_errors = (socket.error, )
        message.reject = Mock(side_effect=socket.error('foo'))
        on_message = self._get_on_message(cons)
        self.assertFalse(on_message(message.decode(), message))
        self.assertTrue(warn.call_count)
        # Nothing may have been buffered or scheduled for the bad message.
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())
        message.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receive_message_eta(self):
        """A task with an ETA one day ahead is held on the timer.

        The blueprint is started once normally and once with
        ``BROKER_CONNECTION_RETRY`` disabled.  After the message callback
        runs, the first timer entry must wrap a Request for ``foo_task`` and
        the buffer must still be empty.
        """
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.steps.pop()
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend, task=self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )

        try:
            l.blueprint.start(l)
            p = l.app.conf.BROKER_CONNECTION_RETRY
            l.app.conf.BROKER_CONNECTION_RETRY = False
            try:
                l.blueprint.start(l)
            finally:
                # Restore the setting even if start() raises, so a failure
                # here cannot leak the override into other tests.
                l.app.conf.BROKER_CONNECTION_RETRY = p
            l.blueprint.restart(l)
            l.event_dispatcher = Mock()
            callback = self._get_on_message(l)
            callback(m.decode(), m)
        finally:
            l.timer.stop()
            l.timer.join()

        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        _eta, _priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, self.foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        """Pidbox reset() tolerates a connection error on channel close."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        box = find_step(cons, consumer.Control).box
        node = box.node = Mock()
        channel = node.channel = Mock()
        cons.connection = Mock()
        channel.close.side_effect = socket.error('foo')
        cons.connection_errors = (socket.error, )
        box.reset()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool, Control uses gPidbox and spawns its loop."""
        from celery.worker.pidbox import gPidbox
        green_pool = Mock()
        green_pool.is_green = True
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, pool=green_pool,
            app=self.app,
        )
        control = find_step(cons, consumer.Control)
        self.assertIsInstance(control.box, gPidbox)
        control.start(cons)
        cons.pool.spawn_n.assert_called_with(control.box.loop, cons)

    def test__green_pidbox_node(self):
        """Run the green pidbox loop once against a fake connection.

        The fake connection times out on its first drain and signals
        shutdown on the second.  Afterwards the node must have been listened
        on, the box consumer consumed from, the consumer's connection
        cleared and the fake connection closed.
        """
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                            app=self.app)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        # Mock usable as a context manager: consume on enter, cancel on exit.
        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        # Minimal stand-in connection; every instance records itself in
        # ``connections`` so the test can inspect it afterwards.
        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call: simulate a timeout.  Later calls: clear the
                # consumer's connection and set the box's shutdown event.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l.connect = lambda: Connection(obj=l)
        controller = find_step(l, consumer.Control)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_connect_errback(self, sleep, connect):
        """connect() completes when the first connection attempt fails.

        The patched ``_establish_connection`` raises ChannelError on its
        first call and returns normally afterwards; ChannelError is
        registered as a connection error on the memory transport for the
        duration of the test only.
        """
        from kombu.transport.memory import Transport
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

        def effect():
            if connect.call_count > 1:
                return
            raise ChannelError()
        connect.side_effect = effect

        # Transport is shared by every test in the process, so the override
        # must be undone or it leaks into unrelated tests.
        saved_errors = Transport.connection_errors
        Transport.connection_errors = (ChannelError, )
        try:
            l.connect()
            connect.assert_called_with()
        finally:
            Transport.connection_errors = saved_errors

    def test_stop_pidbox_node(self):
        """Control.stop() must not raise when the stopped event is set."""
        cons = MyKombuConsumer(
            self.buffer.put, timer=self.timer, app=self.app,
        )
        control = find_step(cons, consumer.Control)
        stopped = control._node_stopped = Event()
        control._node_shutdown = Event()
        stopped.set()
        control.stop(cons)

    def test_start__loop(self):
        """Exercise start() with stubbed loop implementations.

        First scenario: the stub loop raises KeyError on its second call and
        that KeyError must escape start() after exactly two calls.  Second
        scenario: loop raises socket.error, which must escape start().
        """

        # Bare QoS stand-in: update() copies value into prev.
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            # NOTE(review): nothing visible in this test calls
            # reset_connection; confirm whether start() still uses it.
            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.buffer.put, timer=self.timer,
                      init_callback=init_callback, app=self.app)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(*args, **kwargs):
            # Counts calls, syncs the fake QoS, and raises from the second
            # call onwards.
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.loop = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertEqual(l.iterations, 2)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Second scenario: a fresh consumer whose loop raises socket.error.
        init_callback.reset_mock()
        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.loop = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(l.loop.call_count)

    def test_reset_connection_with_no_node(self):
        """The blueprint can start on a consumer that has no pool."""
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.steps.pop()
        # Identity check, matching the assertIsNone calls used elsewhere in
        # this class.
        self.assertIsNone(l.pool)
        l.blueprint.start(l)
Esempio n. 8
0
class test_Consumer(unittest.TestCase):
    def setUp(self):
        """Create the ready queue, ETA timer and logger used by each test."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        log = current_app.log.get_default_logger()
        # Level 0 is logging.NOTSET.
        log.setLevel(0)
        self.logger = log

    def tearDown(self):
        """Stop the timer created in ``setUp``."""
        schedule = self.eta_schedule
        schedule.stop()

    def test_info(self):
        """``info`` exposes the prefetch count and a broker entry."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener.qos = QoS(listener.task_consumer, 10, listener.logger)

        # Before a connection is assigned the broker entry must be falsy.
        before = listener.info
        self.assertEqual(before["prefetch_count"], 10)
        self.assertFalse(before["broker"])

        # Once a connection is assigned the broker entry becomes truthy.
        listener.connection = current_app.broker_connection()
        after = listener.info
        self.assertTrue(after["broker"])

    def test_connection(self):
        """Follow ``connection`` through reset, stop and close calls."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)

        listener.reset_connection()
        self.assertIsInstance(listener.connection, BrokerConnection)

        # Stopping consumers with close_connection=False keeps the connection.
        listener._state = RUN
        listener.event_dispatcher = None
        listener.stop_consumers(close_connection=False)
        self.assertTrue(listener.connection)

        # The default stop clears both the connection and the task consumer.
        listener._state = RUN
        listener.stop_consumers()
        self.assertIsNone(listener.connection)
        self.assertIsNone(listener.task_consumer)

        # A fresh reset re-establishes the connection.
        listener.reset_connection()
        self.assertIsInstance(listener.connection, BrokerConnection)
        listener.stop_consumers()

        listener.stop()
        listener.close_connection()
        self.assertIsNone(listener.connection)
        self.assertIsNone(listener.task_consumer)

    def test_close_connection(self):
        """Stopping consumers closes the event dispatcher and the heart."""
        # close_connection() on a consumer in the RUN state must not raise.
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener._state = RUN
        listener.close_connection()

        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        dispatcher = Mock()
        listener.event_dispatcher = dispatcher
        dispatcher.enabled = True
        fake_heart = MockHeart()
        listener.heart = fake_heart
        listener._state = RUN
        listener.stop_consumers()
        self.assertTrue(dispatcher.close.call_count)
        self.assertTrue(fake_heart.closed)

    def test_receive_message_unknown(self):
        """A message built from only an ``unknown`` field emits a warning."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        message = create_message(Mock(), unknown={"baz": "!!!"})
        listener.event_dispatcher = Mock()
        listener.pidbox_node = MockNode()

        # Assertions stay inside the block while the recorded list is live.
        with catch_warnings(record=True) as caught:
            listener.receive_message(message.decode(), message)
            self.assertTrue(caught)
            self.assertIn("unknown message", caught[0].message.args[0])

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is acknowledged when ``to_timestamp`` overflows."""
        to_timestamp.side_effect = OverflowError()
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        # NOTE(review): ``args`` is the plain string "2, 2" (no trailing
        # comma), not a tuple -- confirm this is intentional.
        eta_now = datetime.now().isoformat()
        message = create_message(Mock(), task=foo_task.name, args=("2, 2"), kwargs={}, eta=eta_now)
        listener.event_dispatcher = Mock()
        listener.pidbox_node = MockNode()

        listener.receive_message(message.decode(), message)
        self.assertTrue(message.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    def test_receive_message_InvalidTaskError(self):
        """A message whose ``kwargs`` is a string is logged as invalid."""
        fake_logger = Mock()
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, fake_logger, send_events=False)
        message = create_message(Mock(), task=foo_task.name, args=(1, 2), kwargs="foobarbaz", id=1)
        listener.event_dispatcher = Mock()
        listener.pidbox_node = MockNode()

        listener.receive_message(message.decode(), message)
        # First positional argument of the last logger.error() call.
        logged = fake_logger.error.call_args[0][0]
        self.assertIn("Received invalid task message", logged)

    def test_on_decode_error(self):
        """``on_decode_error`` acks the message and logs at critical level."""
        fake_logger = Mock()
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, fake_logger, send_events=False)

        class MockMessage(Mock):
            # Concrete values for the attributes set on this fake message.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        undecodable = MockMessage()
        listener.on_decode_error(undecodable, KeyError("foo"))
        self.assertTrue(undecodable.ack.call_count)
        logged = fake_logger.critical.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def test_receieve_message(self):
        """A plain task message ends up on the ready queue as a request."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        message = create_message(Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={})

        listener.event_dispatcher = Mock()
        listener.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing should have been scheduled on the ETA timer.
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """``start()`` survives a connection error; other errors escape."""

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises KeyError, every later call SyntaxError.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        listener = MockConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, pool=BasePool())
        # KeyError is declared a connection error; SyntaxError is not.
        listener.connection_errors = (KeyError,)
        self.assertRaises(SyntaxError, listener.start)
        listener.heart.stop()
        listener.priority_timer.stop()

    def test_consume_messages(self):
        """``consume_messages`` starts consuming and applies the QoS value."""

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clear the owner's connection on every drain.
                self.obj.connection = None

        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener.connection = Connection()
        listener.connection.obj = listener
        listener.task_consumer = Mock()
        listener.qos = QoS(listener.task_consumer, 10, listener.logger)

        listener.consume_messages()
        listener.consume_messages()
        self.assertTrue(listener.task_consumer.consume.call_count)
        listener.task_consumer.qos.assert_called_with(prefetch_count=10)

        # After a decrement the next run must push the lowered value.
        listener.qos.decrement()
        listener.consume_messages()
        listener.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        """``maybe_conn_error`` swallows some errors and re-raises others."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener.connection_errors = (KeyError,)
        listener.channel_errors = (SyntaxError,)

        # None of these three may propagate out of maybe_conn_error().
        for exc_type in (AttributeError, KeyError, SyntaxError):
            listener.maybe_conn_error(Mock(side_effect=exc_type("foo")))

        # IndexError is in neither tuple and must escape.
        failing = Mock(side_effect=IndexError("foo"))
        self.assertRaises(IndexError, listener.maybe_conn_error, failing)

    def test_apply_eta_task(self):
        """``apply_eta_task`` reserves the task, lowers QoS and queues it."""
        from celery.worker import state

        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener.qos = QoS(None, 10, listener.logger)

        sentinel_task = object()
        prefetch_before = listener.qos.value
        listener.apply_eta_task(sentinel_task)

        self.assertIn(sentinel_task, state.reserved_requests)
        self.assertEqual(listener.qos.value, prefetch_before - 1)
        self.assertIs(self.ready_queue.get_nowait(), sentinel_task)

    def test_receieve_message_eta_isoformat(self):
        """A task with an ISO-format ETA is placed on the ETA schedule.

        Also checks that the task consumer's ``qos`` was called while the
        message was handled.
        """
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        m = create_message(Mock(), task=foo_task.name, eta=datetime.now().isoformat(), args=[2, 4, 8], kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # The third field of each queue entry carries the scheduled call;
        # its first argument is the task request.
        items = [entry[2] for entry in self.eta_schedule.queue]
        found = any(item.args[0].task_name == foo_task.name for item in items)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.qos.call_count)
        l.eta_schedule.stop()

    def test_revoke(self):
        """A message whose id was revoked never reaches the ready queue."""
        from celery.worker.state import revoked

        queue = FastQueue()
        listener = MyKombuConsumer(queue, self.eta_schedule, self.logger, send_events=False)
        task_id = gen_unique_id()
        message = create_message(Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={}, id=task_id)

        # Mark the id as revoked before the message is delivered.
        revoked.add(task_id)

        listener.receive_message(message.decode(), message)
        self.assertTrue(queue.empty())

    def test_receieve_message_not_registered(self):
        """A message naming task "x.X.31x" is neither queued nor scheduled."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        message = create_message(Mock(), task="x.X.31x", args=[2, 4, 8], kwargs={})

        listener.event_dispatcher = Mock()
        handled = listener.receive_message(message.decode(), message)
        self.assertFalse(handled)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task due tomorrow is held on the ETA schedule, not queued."""
        listener = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        listener.event_dispatcher = Mock()
        listener.event_dispatcher._outbound_buffer = deque()
        tomorrow = (datetime.now() + timedelta(days=1)).isoformat()
        message = create_message(Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={}, eta=tomorrow)

        listener.reset_connection()
        # Reset once more with connection retry switched off, restoring
        # the original setting whatever happens.
        saved_retry = listener.app.conf.BROKER_CONNECTION_RETRY
        listener.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            listener.reset_connection()
        finally:
            listener.app.conf.BROKER_CONNECTION_RETRY = saved_retry
        listener.stop_consumers()
        listener.event_dispatcher = Mock()
        listener.receive_message(message.decode(), message)
        listener.eta_schedule.stop()

        held = self.eta_schedule.queue[0]
        self.assertEqual(len(held), 3)
        _eta, _priority, entry = held
        request = entry.args[0]
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """Errors raised by ``consume_messages`` escape from ``start()``.

        Both scenarios also check that the init callback was invoked.
        """

        class _QoS(object):
            # Minimal QoS stand-in: update() makes ``prev`` equal ``value``.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Fail as soon as at least one iteration has happened.
                if self.iterations >= 1:
                    raise KeyError("foo")

        on_init = Mock()
        listener = _Consumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, init_callback=on_init)
        listener.task_consumer = Mock()
        listener.broadcast_consumer = Mock()
        listener.qos = _QoS()
        listener.connection = BrokerConnection()
        listener.iterations = 0

        def raises_KeyError(limit=None):
            # Count the call, sync the QoS once, fail from the 2nd call on.
            listener.iterations += 1
            if listener.qos.prev != listener.qos.value:
                listener.qos.update()
            if listener.iterations >= 2:
                raise KeyError("foo")

        listener.consume_messages = raises_KeyError
        self.assertRaises(KeyError, listener.start)
        self.assertTrue(on_init.call_count)
        self.assertEqual(listener.iterations, 1)
        self.assertEqual(listener.qos.prev, listener.qos.value)

        # Second scenario: consume_messages raises socket.error.
        on_init.reset_mock()
        listener = _Consumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, init_callback=on_init)
        listener.qos = _QoS()
        listener.task_consumer = Mock()
        listener.broadcast_consumer = Mock()
        listener.connection = BrokerConnection()
        listener.consume_messages = Mock(side_effect=socket.error("foo"))
        self.assertRaises(socket.error, listener.start)
        self.assertTrue(on_init.call_count)
        self.assertTrue(listener.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """``reset_connection`` works on a consumer created without a pool."""
        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
        # Identity check, consistent with the other None assertions in
        # this class.
        self.assertIsNone(l.pool)
        l.reset_connection()
Esempio n. 9
0
class test_Consumer(ConsumerCase):

    def setup(self):
        """Create the buffer, the timer and a three-argument product task."""
        self.buffer = FastQueue()
        self.timer = Timer()

        def foo_task(x, y, z):
            return x * y * z

        # Register the function on the test app with shared=False.
        self.foo_task = self.app.task(shared=False)(foo_task)

    def teardown(self):
        """Stop the timer created in ``setup``."""
        timer = self.timer
        timer.stop()

    def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
                     without_mingle=True, without_gossip=True,
                     without_heartbeat=True, **kwargs):
        """Build a ``Consumer`` whose collaborators are replaced by mocks.

        Arguments left as ``None`` fall back to this test case's buffer,
        timer and app; ``controller`` falls back to a fresh ``Mock``.
        Extra keyword arguments are forwarded to ``Consumer``.
        """
        if controller is None:
            controller = Mock(name='.controller')
        if buffer is None:
            buffer = self.buffer.put
        if timer is None:
            timer = self.timer
        if app is None:
            app = self.app

        cons = Consumer(
            buffer,
            timer=timer,
            app=app,
            controller=controller,
            without_mingle=without_mingle,
            without_gossip=without_gossip,
            without_heartbeat=without_heartbeat,
            **kwargs
        )
        cons.task_consumer = Mock(name='.task_consumer')
        cons.qos = QoS(cons.task_consumer.qos, 10)
        cons.connection = Mock(name='.connection')
        # The controller passed to the constructor is replaced here by a
        # WorkController created from the consumer's app.
        cons.controller = cons.app.WorkController()
        cons.heart = Mock(name='.heart')
        cons.controller.consumer = cons
        cons.pool = cons.controller.pool = Mock(name='.controller.pool')
        cons.node = Mock(name='.node')
        cons.event_dispatcher = mock_event_dispatcher()
        return cons

    def NoopConsumer(self, *args, **kwargs):
        """Return a ``LoopConsumer`` whose ``loop`` is replaced by a mock."""
        cons = self.LoopConsumer(*args, **kwargs)
        cons.loop = Mock(name='.loop')
        return cons

    def test_info(self):
        """Controller stats include the prefetch count and broker info."""
        worker = self.NoopConsumer()
        worker.connection.info.return_value = {'foo': 'bar'}
        worker.controller.pool.info.return_value = [Mock(), Mock()]

        stats = worker.controller.stats()
        assert stats['prefetch_count'] == 10
        assert stats['broker']

    def test_start_when_closed(self):
        """``start()`` must not raise when the blueprint state is CLOSE."""
        worker = self.NoopConsumer()
        worker.blueprint.state = CLOSE
        worker.start()

    def test_connection(self):
        """Follow ``connection`` through blueprint start/restart/shutdown."""
        worker = self.NoopConsumer()

        worker.blueprint.start(worker)
        assert isinstance(worker.connection, Connection)

        # A restart must leave a connection in place.
        worker.blueprint.state = RUN
        worker.event_dispatcher = None
        worker.blueprint.restart(worker)
        assert worker.connection

        # Shutdown clears both the connection and the task consumer.
        worker.blueprint.state = RUN
        worker.shutdown()
        assert worker.connection is None
        assert worker.task_consumer is None

        worker.blueprint.start(worker)
        assert isinstance(worker.connection, Connection)
        worker.blueprint.restart(worker)

        worker.stop()
        worker.shutdown()
        assert worker.connection is None
        assert worker.task_consumer is None

    def test_close_connection(self):
        """The Connection step's shutdown closes and clears the connection."""
        worker = self.NoopConsumer()
        worker.blueprint.state = RUN
        conn_step = find_step(worker, consumer.Connection)
        # Keep a reference: the attribute is expected to be None afterwards.
        original_conn = worker.connection
        conn_step.shutdown(worker)
        original_conn.close.assert_called()
        assert worker.connection is None

    def test_close_connection__heart_shutdown(self):
        """Shutting down Events and Heart closes dispatcher and stops heart."""
        worker = self.NoopConsumer()
        # Grab both objects up front, before the steps are shut down.
        dispatcher = worker.event_dispatcher
        heart = worker.heart
        worker.event_dispatcher.enabled = True
        worker.blueprint.state = RUN

        events_step = find_step(worker, consumer.Events)
        events_step.shutdown(worker)
        heart_step = find_step(worker, consumer.Heart)
        heart_step.shutdown(worker)

        dispatcher.close.assert_called()
        heart.stop.assert_called_with()

    @patch('celery.worker.consumer.consumer.warn')
    def test_receive_message_unknown(self, warn):
        """A message built from only an ``unknown`` field triggers ``warn``."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.steps.pop()
        message = create_message(Mock(name='.channeol'),
                                 unknown={'baz': '!!!'})

        on_message = self._get_on_message(worker)
        on_message(message)
        warn.assert_called()

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is acknowledged when ``to_timestamp`` overflows."""
        to_timestamp.side_effect = OverflowError()
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.steps.pop()
        # NOTE(review): ``args`` is the plain string '2, 2' (no trailing
        # comma), not a tuple -- confirm this is intentional.
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=('2, 2'), kwargs={},
            eta=datetime.now().isoformat(),
        )
        worker.update_strategies()
        on_message = self._get_on_message(worker)
        on_message(message)
        assert message.acknowledged

    @patch('celery.worker.consumer.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        """An ``InvalidTaskError`` from the strategy is logged via ``error``."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.steps.pop()
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1)
        worker.update_strategies()
        # Replace the task's strategy with one that always fails.
        failing_strategy = Mock(name='strategy')
        worker.strategies[self.foo_task.name] = failing_strategy
        failing_strategy.side_effect = InvalidTaskError()

        on_message = self._get_on_message(worker)
        on_message(message)
        error.assert_called()
        assert 'Received invalid task message' in error.call_args[0][0]

    @patch('celery.worker.consumer.consumer.crit')
    def test_on_decode_error(self, crit):
        """``on_decode_error`` acks the message and reports via ``crit``."""
        worker = self.LoopConsumer()

        class MockMessage(Mock):
            # Concrete values for the attributes set on this fake message.
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        undecodable = MockMessage()
        worker.on_decode_error(undecodable, KeyError('foo'))
        assert undecodable.ack.call_count
        assert "Can't decode message body" in crit.call_args[0][0]

    def _get_on_message(self, c):
        """Run ``c.loop`` once and return the ``on_message`` it installed.

        The consumer's task consumer, event dispatcher and connection are
        replaced by mocks; ``drain_events`` raises ``WorkerShutdown`` so
        the loop ends immediately.
        """
        if c.qos is None:
            c.qos = Mock()
        c.task_consumer = Mock()
        c.event_dispatcher = mock_event_dispatcher()

        fake_conn = Mock(name='.connection')
        fake_conn.get_heartbeat_interval.return_value = 0
        fake_conn.drain_events.side_effect = WorkerShutdown()
        c.connection = fake_conn

        with pytest.raises(WorkerShutdown):
            c.loop(*c.loop_args())
        assert c.task_consumer.on_message
        return c.task_consumer.on_message

    def test_receieve_message(self):
        """A plain task message ends up in the buffer as a ``Request``."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        message = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        worker.update_strategies()
        on_message = self._get_on_message(worker)
        on_message(message)

        request = self.buffer.get_nowait()
        assert isinstance(request, Request)
        assert request.name == self.foo_task.name
        assert request.execute() == 2 * 4 * 8
        # Nothing should have been scheduled on the timer.
        assert self.timer.empty()

    def test_start_channel_error(self):
        """A ``KeyError`` listed as a channel error escapes ``start()``."""
        worker = self.NoopConsumer(task_events=False, pool=BasePool())
        worker.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        worker.channel_errors = (KeyError,)
        try:
            with pytest.raises(KeyError):
                worker.start()
        finally:
            # Stop the timer only when the consumer has one.
            if worker.timer:
                worker.timer.stop()

    def test_start_connection_error(self):
        """With ``KeyError`` a connection error, ``SyntaxError`` escapes."""
        worker = self.NoopConsumer(task_events=False, pool=BasePool())
        worker.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        worker.connection_errors = (KeyError,)
        try:
            with pytest.raises(SyntaxError):
                worker.start()
        finally:
            # Stop the timer only when the consumer has one.
            if worker.timer:
                worker.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        """``loop`` returns normally when draining raises ``socket.timeout``."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clear the owner's connection, then time out.
                self.obj.connection = None
                raise socket.timeout(10)

        worker = self.NoopConsumer()
        worker.connection = Connection(self.app.conf.broker_url)
        worker.connection.obj = worker
        worker.qos = QoS(worker.task_consumer.qos, 10)
        worker.loop(*worker.loop_args())

    def test_loop_when_socket_error(self):
        """``socket.error`` escapes ``loop`` in RUN state but not in CLOSE."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clear the owner's connection, then fail.
                self.obj.connection = None
                raise socket.error('foo')

        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        failing_conn = Connection(self.app.conf.broker_url)
        worker.connection = failing_conn
        worker.connection.obj = worker
        worker.qos = QoS(worker.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            worker.loop(*worker.loop_args())

        # drain_events cleared the attribute; re-attach the connection and
        # run again with the blueprint closed -- this time nothing raises.
        worker.blueprint.state = CLOSE
        worker.connection = failing_conn
        worker.loop(*worker.loop_args())

    def test_loop(self):
        """``loop`` starts consuming and QoS changes reach the consumer."""

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clear the owner's connection on every drain.
                self.obj.connection = None

            @property
            def supports_heartbeats(self):
                return False

        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.connection = Connection(self.app.conf.broker_url)
        worker.connection.obj = worker
        worker.connection.get_heartbeat_interval = Mock(return_value=None)
        worker.qos = QoS(worker.task_consumer.qos, 10)

        worker.loop(*worker.loop_args())
        worker.loop(*worker.loop_args())
        assert worker.task_consumer.consume.call_count
        worker.task_consumer.qos.assert_called_with(prefetch_count=10)
        assert worker.qos.value == 10

        # The decrement is visible at once; update() then pushes it to
        # the task consumer.
        worker.qos.decrement_eventually()
        assert worker.qos.value == 9
        worker.qos.update()
        assert worker.qos.value == 9
        worker.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        """``ignore_errors`` swallows listed errors and re-raises others."""
        worker = self.NoopConsumer()
        worker.connection_errors = (AttributeError, KeyError,)
        worker.channel_errors = (SyntaxError,)

        # Every exception type listed above must be swallowed.
        for exc_type in (AttributeError, KeyError, SyntaxError):
            ignore_errors(worker, Mock(side_effect=exc_type('foo')))

        # IndexError is in neither tuple and must escape.
        with pytest.raises(IndexError):
            ignore_errors(worker, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        """``apply_eta_task`` reserves the task, lowers QoS and buffers it."""
        worker = self.NoopConsumer()
        worker.qos = QoS(None, 10)
        fake_task = Mock(name='task', id='1234213')
        prefetch_before = worker.qos.value

        worker.apply_eta_task(fake_task)

        assert fake_task in state.reserved_requests
        assert worker.qos.value == prefetch_before - 1
        assert self.buffer.get_nowait() is fake_task

    def test_receieve_message_eta_isoformat(self):
        """A task with an ISO-format ETA a day ahead lands on the timer.

        Also checks that the QoS value ends up above its starting value
        after the message is handled.
        """
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        c.qos = QoS(c.task_consumer.qos, 1)
        current_pcount = c.qos.value
        c.event_dispatcher.enabled = False
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        c.timer.stop()
        c.timer.join(1)

        # The third field of each queue entry carries the scheduled call;
        # its first argument is the task request.
        items = [entry[2] for entry in self.timer.queue]
        found = any(
            item.args[0].name == self.foo_task.name for item in items
        )
        assert found
        assert c.qos.value > current_pcount
        c.timer.stop()

    def test_pidbox_callback(self):
        """``on_message`` forwards to the node and tolerates its errors."""
        worker = self.NoopConsumer()
        box = find_step(worker, consumer.Control).box
        box.node = Mock()
        box.reset = Mock()

        # Plain delivery.
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        # A KeyError from the node must not propagate.
        box.node = Mock()
        box.node.handle_message.side_effect = KeyError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')

        # Neither must a ValueError; by now reset() must have been called.
        box.node = Mock()
        box.node.handle_message.side_effect = ValueError('foo')
        box.on_message('foo', 'bar')
        box.node.handle_message.assert_called_with('foo', 'bar')
        box.reset.assert_called()

    def test_revoke(self):
        """A message whose id was revoked never reaches the buffer."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.steps.pop()
        task_id = uuid()
        message = self.create_task_message(
            Mock(name='channel'), self.foo_task.name,
            args=[2, 4, 8], kwargs={}, id=task_id,
        )

        # Mark the id as revoked before the message is delivered.
        state.revoked.add(task_id)

        on_message = self._get_on_message(worker)
        on_message(message)
        assert self.buffer.empty()

    def test_receieve_message_not_registered(self):
        """A message naming task 'x.X.31x' is neither buffered nor timed."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        worker.steps.pop()
        message = self.create_task_message(
            Mock(name='channel'), 'x.X.31x', args=[2, 4, 8], kwargs={},
        )

        on_message = self._get_on_message(worker)
        assert not on_message(message)
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()

    @patch('celery.worker.consumer.consumer.warn')
    @patch('celery.worker.consumer.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        """A header-less message is rejected through ``reject_log_error``."""
        worker = self.LoopConsumer()
        worker.blueprint.state = RUN
        message = self.create_task_message(
            Mock(name='channel'), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        message.headers = None

        worker.update_strategies()
        worker.connection_errors = (socket.error,)
        # Make a direct reject() call fail with a connection error.
        message.reject = Mock()
        message.reject.side_effect = socket.error('foo')

        on_message = self._get_on_message(worker)
        assert not on_message(message)
        warn.assert_called()
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()
        message.reject_log_error.assert_called_with(
            logger, worker.connection_errors)

    def test_receive_message_eta(self):
        """A task with an ETA one day ahead is held by the timer.

        The request must sit at the head of the timer queue and must not
        appear in the buffer. Set ``C_DEBUG_TEST`` in the environment to
        trace each stage on the interpreter's original stderr.
        """
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        c = self.LoopConsumer()
        pp('-CREATE MYKOMBUCONSUMER')
        c.steps.pop()
        channel = Mock(name='channel')
        pp('+ CREATE MESSAGE')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 1')
            p = c.app.conf.broker_connection_retry
            c.app.conf.broker_connection_retry = False
            try:
                pp('+ BLUEPRINT START 2')
                c.blueprint.start(c)
                pp('- BLUEPRINT START 2')
            finally:
                # Restore the setting even when the start above raises,
                # so the override never outlives this block.
                c.app.conf.broker_connection_retry = p
            pp('+ BLUEPRINT RESTART')
            c.blueprint.restart(c)
            pp('- BLUEPRINT RESTART')
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(c)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            c.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                c.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                # Best effort: a RuntimeError from join() is ignored.
                pass

        in_hold = c.timer.queue[0]
        assert len(in_hold) == 3
        eta, priority, entry = in_hold
        task = entry.args[0]
        assert isinstance(task, Request)
        assert task.name == self.foo_task.name
        assert task.execute() == 2 * 4 * 8
        with pytest.raises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        """``reset()`` closes the node channel even if closing fails."""
        worker = self.NoopConsumer()
        box = find_step(worker, consumer.Control).box
        box.node = Mock()
        node_channel = Mock()
        box.node.channel = node_channel
        # Closing fails with an error listed as a connection error.
        node_channel.close.side_effect = socket.error('foo')
        worker.connection_errors = (socket.error,)
        box.reset()
        node_channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool the Control step spawns the ``gPidbox`` loop."""
        worker = self.NoopConsumer(pool=Mock(is_green=True))
        control = find_step(worker, consumer.Control)
        assert isinstance(control.box, gPidbox)
        control.start(worker)
        worker.pool.spawn_n.assert_called_with(control.box.loop, worker)

    def test_green_pidbox_node(self):
        """Run the green pidbox loop against a fake connection.

        The fake connection times out on the first drain and requests
        shutdown on the second. Afterwards the listener must have been
        consumed and the connection closed.
        """
        pool = Mock()
        pool.is_green = True
        # Pass the pool prepared above instead of building a duplicate.
        c = self.NoopConsumer(pool=pool)
        controller = find_step(c, consumer.Control)

        class BConsumer(Mock):
            # Context manager that consumes on entry and cancels on exit.

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                # Record every instance so the test can inspect it later.
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call: time out. Later calls: clear the owner's
                # connection and signal the box to shut down.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        c.connection_for_read = lambda: Connection(obj=c)
        controller = find_step(c, consumer.Control)
        controller.box.loop(c)

        controller.box.node.listen.assert_called()
        assert controller.box.consumer
        controller.box.consumer.consume.assert_called_with()

        assert c.connection is None
        assert connections[0].closed

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.functional.sleep')
    def test_connect_errback(self, sleep, connect):
        """``connect()`` calls ``_establish_connection`` despite one failure.

        ``sleep`` is patched out and the first connection attempt is set
        up to produce a ``ChannelError``.
        """
        c = self.NoopConsumer()
        # The override is applied on the class itself; restore it
        # afterwards so it cannot leak into other tests.
        previous_errors = Transport.connection_errors
        Transport.connection_errors = (ChannelError,)
        try:
            connect.on_nth_call_do(ChannelError('error'), n=1)
            c.connect()
            connect.assert_called_with()
        finally:
            Transport.connection_errors = previous_errors

    def test_stop_pidbox_node(self):
        """The Control step stops cleanly once ``_node_stopped`` is set."""
        worker = self.NoopConsumer()
        control = find_step(worker, consumer.Control)
        control._node_stopped = Event()
        control._node_shutdown = Event()
        # Flag the node as already stopped before calling stop().
        control._node_stopped.set()
        control.stop(worker)

    def test_start__loop(self):
        """Errors raised by ``loop`` must escape from ``start()``.

        Exercised twice: with a ``KeyError`` raised on the second ``loop``
        call, and with a mocked ``loop`` raising ``socket.error``.
        """

        class _QoS(object):
            # Minimal QoS stand-in: update() makes ``prev`` equal ``value``.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        on_init = Mock(name='init_callback')
        worker = self.NoopConsumer(init_callback=on_init)
        worker.qos = _QoS()
        worker.connection = Connection(self.app.conf.broker_url)
        worker.connection.get_heartbeat_interval = Mock(return_value=None)
        worker.iterations = 0

        def raises_KeyError(*args, **kwargs):
            # Count the call, sync the QoS once, then fail on the 2nd call.
            worker.iterations += 1
            if worker.qos.prev != worker.qos.value:
                worker.qos.update()
            if worker.iterations >= 2:
                raise KeyError('foo')

        worker.loop = raises_KeyError
        with pytest.raises(KeyError):
            worker.start()
        assert worker.iterations == 2
        assert worker.qos.prev == worker.qos.value

        # Second scenario: the loop itself raises socket.error.
        on_init.reset_mock()
        worker = self.NoopConsumer(task_events=False, init_callback=on_init)
        worker.qos = _QoS()
        worker.connection = Connection(self.app.conf.broker_url)
        worker.connection.get_heartbeat_interval = Mock(return_value=None)
        worker.loop = Mock(side_effect=socket.error('foo'))
        with pytest.raises(socket.error):
            worker.start()
        worker.loop.assert_called()

    def test_reset_connection_with_no_node(self):
        """The blueprint starts after the consumer's last step is removed."""
        worker = self.NoopConsumer()
        worker.steps.pop()
        worker.blueprint.start(worker)
Esempio n. 10
0
class test_Consumer(unittest.TestCase):
    """Exercise ``MyKombuConsumer`` / ``MainConsumer`` message handling.

    Every test gets a fresh ready queue and ETA timer from :meth:`setUp`;
    consumers are built through :meth:`_consumer` so they all share them.
    """

    def setUp(self):
        """Create the ready queue, ETA timer and logger used by each test."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = app_or_default().log.get_default_logger()
        # Level 0 is logging.NOTSET.
        self.logger.setLevel(0)

    def tearDown(self):
        """Stop the ETA timer created in :meth:`setUp`."""
        self.eta_schedule.stop()

    def _consumer(self, cls=None, ready_queue=None, logger=None, **kwargs):
        """Build a consumer wired to this test's queue, timer and logger.

        :param cls: consumer class to instantiate (``MyKombuConsumer``
            when omitted).
        :param ready_queue: queue to use instead of ``self.ready_queue``.
        :param logger: logger to use instead of ``self.logger``.
        :param kwargs: extra keyword arguments for the constructor;
            ``send_events`` defaults to ``False``.
        :returns: the new consumer instance.
        """
        if cls is None:
            cls = MyKombuConsumer
        if ready_queue is None:
            ready_queue = self.ready_queue
        if logger is None:
            logger = self.logger
        kwargs.setdefault("send_events", False)
        return cls(ready_queue, self.eta_schedule, logger, **kwargs)

    def test_info(self):
        """``info`` exposes the prefetch count and the broker state."""
        l = self._consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = app_or_default().broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_connection(self):
        """Connections appear on reset and are dropped when stopping."""
        l = self._consumer()

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        # Stopping with close=False must leave the connection in place.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        """Stopping consumers closes the event dispatcher and the heart."""
        l = self._consumer()
        l._state = RUN
        l.close_connection()

        l = self._consumer()
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """A message without a task name triggers an "unknown message" warning."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        # execute_context is used in place of a ``with`` statement.
        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        """An ``OverflowError`` from ``to_timestamp`` still acks the message."""
        l = self._consumer()
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            called[0] = True
            raise OverflowError()

        # NOTE(review): ``("2, 2")`` is a plain string, not a tuple --
        # kept as-is; confirm whether a tuple was intended.
        m = create_message(backend, task=foo_task.name,
                           args=("2, 2"),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        from celery.worker import consumer
        # Swap in the failing converter and always restore the original.
        prev, consumer.to_timestamp = consumer.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            consumer.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        """Non-mapping ``kwargs`` is logged as "Invalid task ignored"."""
        logger = MockLogger()
        l = self._consumer(logger=logger)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        """``on_decode_error`` acks the message and logs the failure."""
        logger = MockLogger()
        l = self._consumer(logger=logger)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Can't decode message body", logger.logged[0])

    def test_receieve_message(self):
        """A task without ETA lands in the ready queue, not the timer."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """``start`` survives a connection error but not other errors."""

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises the "connection error", every later
                # call raises an error outside ``connection_errors``.
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = self._consumer(MockConsumer)
        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.heart.stop()

    def test_consume_messages(self):
        """``consume_messages`` starts consuming and applies the QoS value."""
        app = app_or_default()

        class Connection(app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clearing the owner's connection on the first drain.
                self.obj.connection = None

        class Consumer(object):
            consuming = False
            prefetch_count = 0

            def consume(self):
                self.consuming = True

            def qos(self, prefetch_size=0, prefetch_count=0,
                    apply_global=False):
                self.prefetch_count = prefetch_count

        l = self._consumer()
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Consumer()
        l.broadcast_consumer = Consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consuming)
        self.assertTrue(l.broadcast_consumer.consuming)
        self.assertEqual(l.task_consumer.prefetch_count, 10)

        l.qos.decrement()
        l.consume_messages()
        self.assertEqual(l.task_consumer.prefetch_count, 9)

    def test_maybe_conn_error(self):
        """``maybe_conn_error`` swallows the expected error types only."""

        def raises(error):

            def fun():
                raise error

            return fun

        l = self._consumer()
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(raises(AttributeError("foo")))
        l.maybe_conn_error(raises(KeyError("foo")))
        l.maybe_conn_error(raises(SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error,
                          raises(IndexError("foo")))

    def test_apply_eta_task(self):
        """``apply_eta_task`` reserves the task, queues it, decrements QoS."""
        from celery.worker import state
        l = self._consumer()
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        """An ISO-formatted ETA schedules the task and bumps the prefetch."""

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # The third field of each timer entry carries the scheduled call
        # whose first argument is the task request.
        found = False
        for entry in self.eta_schedule.queue:
            if entry[2].args[0].task_name == foo_task.name:
                found = True
                break
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        """A message whose id is in ``revoked`` never reaches the queue."""
        ready_queue = FastQueue()
        l = self._consumer(ready_queue=ready_queue)
        backend = MockBackend()
        task_id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """An unregistered task name is dropped without queueing anything."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task with a future ETA is held in the timer, not the queue."""
        l = self._consumer()
        l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        l.reset_connection()
        # Reset once more with connection retry disabled, restoring the
        # setting afterwards.
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """``start`` calls ``init_callback`` and propagates loop errors."""

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Fails as soon as one consume iteration has run.
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        l = self._consumer(_Consumer, init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Reset the flag so the assertion below proves that *this* start()
        # invoked the callback instead of reusing the first run's result.
        called_back[0] = False
        l = self._consumer(_Consumer, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            l.iterations = 1
            raise socket.error("foo")

        l.consume_messages = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
Esempio n. 11
0
class test_Consumer(unittest.TestCase):
    """Exercise ``MyKombuConsumer`` / ``MainConsumer`` message handling.

    Every test gets a fresh ready queue and ETA timer from :meth:`setUp`;
    consumers are built through :meth:`_consumer` so they all share them.
    """

    def setUp(self):
        """Create the ready queue, ETA timer and logger used by each test."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = current_app.log.get_default_logger()
        # Level 0 is logging.NOTSET.
        self.logger.setLevel(0)

    def tearDown(self):
        """Stop the ETA timer created in :meth:`setUp`."""
        self.eta_schedule.stop()

    def _consumer(self, cls=None, ready_queue=None, logger=None, **kwargs):
        """Build a consumer wired to this test's queue, timer and logger.

        :param cls: consumer class to instantiate (``MyKombuConsumer``
            when omitted).
        :param ready_queue: queue to use instead of ``self.ready_queue``.
        :param logger: logger to use instead of ``self.logger``.
        :param kwargs: extra keyword arguments for the constructor;
            ``send_events`` defaults to ``False``.
        :returns: the new consumer instance.
        """
        if cls is None:
            cls = MyKombuConsumer
        if ready_queue is None:
            ready_queue = self.ready_queue
        if logger is None:
            logger = self.logger
        kwargs.setdefault("send_events", False)
        return cls(ready_queue, self.eta_schedule, logger, **kwargs)

    def test_info(self):
        """``info`` exposes the prefetch count and the broker state."""
        l = self._consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = current_app.broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_connection(self):
        """Connections appear on reset and are dropped when stopping."""
        l = self._consumer()

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        # Stopping with close=False must leave the connection in place.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        """Stopping consumers closes the event dispatcher and the heart."""
        l = self._consumer()
        l._state = RUN
        l.close_connection()

        l = self._consumer()
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """A message without a task name triggers an "unknown message" warning."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        # execute_context is used in place of a ``with`` statement.
        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        """An ``OverflowError`` from ``to_timestamp`` still acks the message."""
        l = self._consumer()
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            called[0] = True
            raise OverflowError()

        # NOTE(review): ``("2, 2")`` is a plain string, not a tuple --
        # kept as-is; confirm whether a tuple was intended.
        m = create_message(backend,
                           task=foo_task.name,
                           args=("2, 2"),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        # Swap in the failing converter and always restore the original.
        prev, timer2.to_timestamp = timer2.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            timer2.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        """Non-mapping ``kwargs`` is logged as "Invalid task ignored"."""
        logger = MockLogger()
        l = self._consumer(logger=logger)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=(1, 2),
                           kwargs="foobarbaz",
                           id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        """``on_decode_error`` acks the message and logs the failure."""
        logger = MockLogger()
        l = self._consumer(logger=logger)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Can't decode message body", logger.logged[0])

    def test_receieve_message(self):
        """A task without ETA lands in the ready queue, not the timer."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """``start`` survives a connection error but not other errors."""

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises the "connection error", every later
                # call raises an error outside ``connection_errors``.
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = self._consumer(MockConsumer, pool=BasePool())
        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.heart.stop()

    def test_consume_messages(self):
        """``consume_messages`` starts consuming and applies the QoS value."""

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clearing the owner's connection on the first drain.
                self.obj.connection = None

        class Consumer(object):
            consuming = False
            prefetch_count = 0

            def consume(self):
                self.consuming = True

            def qos(self,
                    prefetch_size=0,
                    prefetch_count=0,
                    apply_global=False):
                self.prefetch_count = prefetch_count

        l = self._consumer()
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consuming)
        self.assertEqual(l.task_consumer.prefetch_count, 10)

        l.qos.decrement()
        l.consume_messages()
        self.assertEqual(l.task_consumer.prefetch_count, 9)

    def test_maybe_conn_error(self):
        """``maybe_conn_error`` swallows the expected error types only."""

        def raises(error):
            def fun():
                raise error

            return fun

        l = self._consumer()
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(raises(AttributeError("foo")))
        l.maybe_conn_error(raises(KeyError("foo")))
        l.maybe_conn_error(raises(SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error,
                          raises(IndexError("foo")))

    def test_apply_eta_task(self):
        """``apply_eta_task`` reserves the task, queues it, decrements QoS."""
        from celery.worker import state
        l = self._consumer()
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        """An ISO-formatted ETA schedules the task and bumps the prefetch."""

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # The third field of each timer entry carries the scheduled call
        # whose first argument is the task request.
        found = False
        for entry in self.eta_schedule.queue:
            if entry[2].args[0].task_name == foo_task.name:
                found = True
                break
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        """A message whose id is in ``revoked`` never reaches the queue."""
        ready_queue = FastQueue()
        l = self._consumer(ready_queue=ready_queue)
        backend = MockBackend()
        task_id = gen_unique_id()
        t = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """An unregistered task name is dropped without queueing anything."""
        l = self._consumer()
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task with a future ETA is held in the timer, not the queue."""
        l = self._consumer()
        l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        l.reset_connection()
        # Reset once more with connection retry disabled, restoring the
        # setting afterwards.
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """``start`` calls ``init_callback`` and propagates loop errors."""

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Fails as soon as one consume iteration has run.
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        l = self._consumer(_Consumer, init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Reset the flag so the assertion below proves that *this* start()
        # invoked the callback instead of reusing the first run's result.
        called_back[0] = False
        l = self._consumer(_Consumer, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            l.iterations = 1
            raise socket.error("foo")

        l.consume_messages = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
Esempio n. 12
0
class test_CarrotListener(unittest.TestCase):
    def setUp(self):
        """Create the ready queue, ETA timer and logger used by each test."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        logger = log.get_default_logger()
        logger.setLevel(0)
        self.logger = logger

    def tearDown(self):
        """Stop the ETA timer created in ``setUp``."""
        self.eta_schedule.stop()

    def test_mainloop(self):
        """Check ``_mainloop`` and the choice made by ``_detect_wait_method``.

        With an inner connection exposing ``drain_events`` the listener is
        expected to pick ``_mainloop``; with a bare ``PlaceHolder`` it is
        expected to fall back to ``task_consumer.iterconsume``.
        """
        l = MyCarrotListener(self.ready_queue,
                             self.eta_schedule,
                             self.logger,
                             send_events=False)

        class MockConnection(object):
            def drain_events(self):
                return "draining"

        l.connection = MockConnection()
        l.connection.connection = MockConnection()

        it = l._mainloop()
        # assertEqual rather than assertTrue: assertTrue(x, "draining") uses
        # "draining" as the failure message and passes for any truthy value.
        self.assertEqual(it.next(), "draining")
        records = {}

        def create_recorder(key):
            # Build a callable that notes under *key* that it was invoked.
            def _recorder(*args, **kwargs):
                records[key] = True

            return _recorder

        l.task_consumer = PlaceHolder()
        l.task_consumer.iterconsume = create_recorder("consume_tasks")
        l.broadcast_consumer = PlaceHolder()
        l.broadcast_consumer.register_callback = create_recorder(
            "broadcast_callback")
        l.broadcast_consumer.iterconsume = create_recorder("consume_broadcast")
        l.task_consumer.add_consumer = create_recorder("consumer_add")

        # Inner connection is a MockConnection: _mainloop is selected and
        # both consumers are set up.
        records.clear()
        self.assertEqual(l._detect_wait_method(), l._mainloop)
        for record in ("broadcast_callback", "consume_broadcast",
                       "consume_tasks"):
            self.assertTrue(records.get(record))

        # Inner connection is a PlaceHolder: iterconsume is selected and the
        # broadcast consumer is added to the task consumer.
        records.clear()
        l.connection.connection = PlaceHolder()
        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
        self.assertTrue(records.get("consumer_add"))

    def test_connection(self):
        """Connection is set by a reset and cleared again on shutdown."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)

        # First connect / disconnect cycle.
        subject.reset_connection()
        self.assertIsInstance(subject.connection, BrokerConnection)

        subject.stop_consumers()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

        # Reconnect, then shut down through stop() and close_connection().
        subject.reset_connection()
        self.assertIsInstance(subject.connection, BrokerConnection)
        subject.stop_consumers()

        subject.stop()
        subject.close_connection()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

    def test_receive_message_control_command(self):
        """A control message ends up in the control dispatcher's commands."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(),
                                 control={"command": "shutdown"})
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        subject.receive_message(message.decode(), message)

        self.assertIn("shutdown", subject.control_dispatch.commands)

    def test_close_connection(self):
        """Closing in the RUN state, without and with dispatcher/heart."""
        # A listener that is only marked RUN can be closed directly.
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        subject._state = RUN
        subject.close_connection()

        # With a dispatcher and heart attached, stop_consumers() is
        # expected to leave both flagged as closed.
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        dispatcher = MockEventDispatcher()
        subject.event_dispatcher = dispatcher
        heart = MockHeart()
        subject.heart = heart
        subject._state = RUN
        subject.stop_consumers()
        self.assertTrue(dispatcher.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """An unrecognised message yields an "unknown message" warning."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), unknown={"baz": "!!!"})
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        def check_warning(recorded):
            subject.receive_message(message.decode(), message)
            self.assertTrue(recorded)
            first = recorded[0]
            self.assertIn("unknown message", first.message.args[0])

        # Presumably execute_context() enters the context manager and calls
        # the function with the value it yields (the recorded warnings).
        execute_context(catch_warnings(record=True), check_warning)

    def test_receive_message_eta_OverflowError(self):
        """A message is acknowledged even if the ETA conversion overflows."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        backend = MockBackend()
        # One-element list so the stub below can record that it ran.
        called = [False]

        def to_timestamp(d):
            called[0] = True
            raise OverflowError()

        # NOTE(review): ("2, 2") has no trailing comma, so ``args`` is the
        # string "2, 2" rather than a tuple.
        message = create_message(
            backend,
            task=foo_task.name,
            args=("2, 2"),
            kwargs={},
            eta=datetime.now().isoformat(),
        )
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        # Swap in the failing converter and always restore the original.
        original = listener.to_timestamp
        listener.to_timestamp = to_timestamp
        try:
            subject.receive_message(message.decode(), message)
            self.assertTrue(message.acknowledged)
            self.assertTrue(called[0])
        finally:
            listener.to_timestamp = original

    def test_receive_message_InvalidTaskError(self):
        """A task message with non-dict kwargs is logged as invalid."""
        recording_logger = MockLogger()
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   recording_logger, send_events=False)
        # ``kwargs`` is deliberately a string instead of a mapping.
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        subject.receive_message(message.decode(), message)

        first_entry = recording_logger.logged[0]
        self.assertIn("Invalid task ignored", first_entry)

    def test_on_decode_error(self):
        """``on_decode_error`` acks the message and logs the failure."""
        recording_logger = MockLogger()
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   recording_logger, send_events=False)

        class MockMessage(object):
            # Just enough of a message for on_decode_error() to report on.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        undecodable = MockMessage()
        subject.on_decode_error(undecodable, KeyError("foo"))

        self.assertTrue(undecodable.acked)
        first_entry = recording_logger.logged[0]
        self.assertIn("Message decoding error", first_entry)

    def test_receieve_message(self):
        """A plain task message lands on the ready queue as a request."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={})

        subject.event_dispatcher = MockEventDispatcher()
        subject.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing was given an ETA, so the timer must still be empty.
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        """A task with an ISO-format ETA is scheduled and bumps prefetch."""

        class MockConsumer(object):
            # Records whether qos() was called on the consumer.
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyCarrotListener(self.ready_queue,
                             self.eta_schedule,
                             self.logger,
                             send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Snapshot the scheduled entries, then look for our task; any()
        # replaces the previous manual flag that mixed 0 and True.
        items = [entry[2] for entry in self.eta_schedule.queue]
        found = any(item.args[0].task_name == foo_task.name
                    for item in items)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)
        l.eta_schedule.stop()

    def test_revoke(self):
        """A task revoked by a control message is not put on the queue."""
        ready_queue = FastQueue()
        subject = MyCarrotListener(ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()
        revoke_command = {"command": "revoke", "task_id": task_id}
        control_message = create_message(backend, control=revoke_command)
        task_message = create_message(backend, task=foo_task.name,
                                      args=[2, 4, 8], kwargs={}, id=task_id)
        subject.event_dispatcher = MockEventDispatcher()

        # Deliver the revoke first and confirm the id was registered.
        subject.receive_message(control_message.decode(), control_message)
        from celery.worker.state import revoked
        self.assertIn(task_id, revoked)

        # The task itself must then be dropped.
        subject.receive_message(task_message.decode(), task_message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """A message naming an unknown task is neither queued nor scheduled."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        subject.event_dispatcher = MockEventDispatcher()
        outcome = subject.receive_message(message.decode(), message)

        self.assertFalse(outcome)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task due tomorrow is held in the timer, not the ready queue."""
        subject = MyCarrotListener(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        dispatcher = MockEventDispatcher()
        subject.event_dispatcher = dispatcher
        backend = MockBackend()
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_message(backend, task=foo_task.name,
                                 args=[2, 4, 8], kwargs={},
                                 eta=tomorrow.isoformat())

        subject.reset_connection()
        # Reset once more with connection retries switched off, restoring
        # the setting afterwards.
        saved_retry = conf.BROKER_CONNECTION_RETRY
        conf.BROKER_CONNECTION_RETRY = False
        try:
            subject.reset_connection()
        finally:
            conf.BROKER_CONNECTION_RETRY = saved_retry
        subject.stop_consumers()
        self.assertTrue(dispatcher.flushed)

        subject.event_dispatcher = MockEventDispatcher()
        subject.receive_message(message.decode(), message)
        subject.eta_schedule.stop()

        # Each timer entry is a 3-item record whose last item wraps the task.
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        task = in_hold[2].args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """Drive ``start()`` with wait methods that fail after one yield.

        Both scenarios expect ``KeyError`` from ``start()``, the init
        callback flag to be set and exactly one recorded iteration.
        """

        class _QoS(object):
            # Minimal QoS stand-in: update() copies ``next`` onto ``prev``.
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Listener(MyCarrotListener):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Raise instead of connecting once an iteration is recorded.
                if self.iterations >= 1:
                    raise KeyError("foo")

            def _detect_wait_method(self):
                return self.wait_method

        # One-element list so the nested callback can flip the flag.
        called_back = [False]

        def init_callback(listener):
            called_back[0] = True

        def new_listener():
            return _Listener(self.ready_queue, self.eta_schedule, self.logger,
                             send_events=False, init_callback=init_callback)

        # Scenario 1: the wait method yields once, then raises KeyError.
        subject = new_listener()
        subject.qos = _QoS()

        def raises_KeyError(limit=None):
            yield True
            subject.iterations = 1
            raise KeyError("foo")

        subject.wait_method = raises_KeyError
        self.assertRaises(KeyError, subject.start)
        self.assertTrue(called_back[0])
        self.assertEqual(subject.iterations, 1)
        self.assertEqual(subject.qos.prev, subject.qos.next)

        # Scenario 2: the wait method raises socket.error instead.
        subject = new_listener()
        subject.qos = _QoS()

        def raises_socket_error(limit=None):
            yield True
            subject.iterations = 1
            raise socket.error("foo")

        subject.wait_method = raises_socket_error
        # KeyError is still what escapes: presumably start() reacts to the
        # socket error by calling reset_connection() again, which raises
        # because ``iterations`` is already 1 -- TODO confirm in start().
        self.assertRaises(KeyError, subject.start)
        # NOTE(review): the flag is still True from scenario 1, so this
        # assertion cannot fail on its own.
        self.assertTrue(called_back[0])
        self.assertEqual(subject.iterations, 1)
Esempio n. 13
0
class test_CarrotListener(unittest.TestCase):

    def setUp(self):
        """Create the ready queue, ETA timer and logger used by each test."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        logger = get_logger()
        logger.setLevel(0)
        self.logger = logger

    def tearDown(self):
        """Stop the ETA timer created in ``setUp``."""
        self.eta_schedule.stop()

    def test_mainloop(self):
        """Check ``_mainloop`` and the choice made by ``_detect_wait_method``.

        With an inner connection exposing ``drain_events`` the listener is
        expected to pick ``_mainloop``; with a bare ``PlaceHolder`` it is
        expected to fall back to ``task_consumer.iterconsume``.
        """
        l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)

        class MockConnection(object):

            def drain_events(self):
                return "draining"

        l.connection = MockConnection()
        l.connection.connection = MockConnection()

        it = l._mainloop()
        # assertEqual rather than assertTrue: assertTrue(x, "draining") uses
        # "draining" as the failure message and passes for any truthy value.
        self.assertEqual(it.next(), "draining")
        records = {}

        def create_recorder(key):
            # Build a callable that notes under *key* that it was invoked.
            def _recorder(*args, **kwargs):
                records[key] = True
            return _recorder

        l.task_consumer = PlaceHolder()
        l.task_consumer.iterconsume = create_recorder("consume_tasks")
        l.broadcast_consumer = PlaceHolder()
        l.broadcast_consumer.register_callback = create_recorder(
                                                    "broadcast_callback")
        l.broadcast_consumer.iterconsume = create_recorder(
                                             "consume_broadcast")
        l.task_consumer.add_consumer = create_recorder("consumer_add")

        # Inner connection is a MockConnection: _mainloop is selected and
        # both consumers are set up.
        records.clear()
        self.assertEqual(l._detect_wait_method(), l._mainloop)
        for record in ("broadcast_callback", "consume_broadcast",
                "consume_tasks"):
            self.assertTrue(records.get(record))

        # Inner connection is a PlaceHolder: iterconsume is selected and the
        # broadcast consumer is added to the task consumer.
        records.clear()
        l.connection.connection = PlaceHolder()
        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
        self.assertTrue(records.get("consumer_add"))

    def test_connection(self):
        """Connection is set by a reset and cleared again on shutdown."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )

        # First connect / disconnect cycle.
        subject.reset_connection()
        self.assertIsInstance(subject.connection, BrokerConnection)

        subject.stop_consumers()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

        # Reconnect, then shut down through stop() and close_connection().
        subject.reset_connection()
        self.assertIsInstance(subject.connection, BrokerConnection)
        subject.stop_consumers()

        subject.stop()
        subject.close_connection()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

    def test_receive_message_control_command(self):
        """A control message ends up in the control dispatcher's commands."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(MockBackend(),
                                 control={"command": "shutdown"})
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        subject.receive_message(message.decode(), message)

        self.assertIn("shutdown", subject.control_dispatch.commands)

    def test_close_connection(self):
        """Closing in the RUN state, without and with dispatcher/heart."""
        # A listener that is only marked RUN can be closed directly.
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        subject._state = RUN
        subject.close_connection()

        # With a dispatcher and heart attached, stop_consumers() is
        # expected to leave both flagged as closed.
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        dispatcher = MockEventDispatcher()
        subject.event_dispatcher = dispatcher
        heart = MockHeart()
        subject.heart = heart
        subject._state = RUN
        subject.stop_consumers()
        self.assertTrue(dispatcher.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """An unrecognised message yields an "unknown message" warning."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(MockBackend(), unknown={"baz": "!!!"})
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        def check_warning(recorded):
            subject.receive_message(message.decode(), message)
            self.assertTrue(recorded)
            first = recorded[0]
            self.assertIn("unknown message", first.message.args[0])

        # Presumably execute_context() enters the context manager and calls
        # the function with the value it yields (the recorded warnings).
        execute_context(catch_warnings(record=True), check_warning)

    def test_receive_message_InvalidTaskError(self):
        """A task message with non-dict kwargs is logged as invalid."""
        recording_logger = MockLogger()
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, recording_logger,
            send_events=False,
        )
        # ``kwargs`` is deliberately a string instead of a mapping.
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        subject.event_dispatcher = MockEventDispatcher()
        subject.control_dispatch = MockControlDispatch()

        subject.receive_message(message.decode(), message)

        first_entry = recording_logger.logged[0]
        self.assertIn("Invalid task ignored", first_entry)

    def test_on_decode_error(self):
        """``on_decode_error`` acks the message and logs the failure."""
        recording_logger = MockLogger()
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, recording_logger,
            send_events=False,
        )

        class MockMessage(object):
            # Just enough of a message for on_decode_error() to report on.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        undecodable = MockMessage()
        subject.on_decode_error(undecodable, KeyError("foo"))

        self.assertTrue(undecodable.acked)
        first_entry = recording_logger.logged[0]
        self.assertIn("Message decoding error", first_entry)

    def test_receieve_message(self):
        """A plain task message lands on the ready queue as a request."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(
            MockBackend(), task=foo_task.name, args=[2, 4, 8], kwargs={},
        )

        subject.event_dispatcher = MockEventDispatcher()
        subject.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing was given an ETA, so the timer must still be empty.
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        """A task with an ISO-format ETA is scheduled and bumps prefetch."""

        class MockConsumer(object):
            # Records whether qos() was called on the consumer.
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        # Snapshot the scheduled entries, then look for our task; any()
        # replaces the previous manual flag that mixed 0 and True.
        items = [entry[2] for entry in self.eta_schedule.queue]
        found = any(item.args[0].task_name == foo_task.name
                    for item in items)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)
        l.eta_schedule.stop()

    def test_revoke(self):
        """A task revoked by a control message is not put on the queue."""
        ready_queue = FastQueue()
        subject = MyCarrotListener(
            ready_queue, self.eta_schedule, self.logger, send_events=False,
        )
        backend = MockBackend()
        task_id = gen_unique_id()
        revoke_command = {"command": "revoke", "task_id": task_id}
        control_message = create_message(backend, control=revoke_command)
        task_message = create_message(
            backend, task=foo_task.name, args=[2, 4, 8], kwargs={},
            id=task_id,
        )
        subject.event_dispatcher = MockEventDispatcher()

        # Deliver the revoke first and confirm the id was registered.
        subject.receive_message(control_message.decode(), control_message)
        from celery.worker.state import revoked
        self.assertIn(task_id, revoked)

        # The task itself must then be dropped.
        subject.receive_message(task_message.decode(), task_message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """A message naming an unknown task is neither queued nor scheduled."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        message = create_message(MockBackend(), task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        subject.event_dispatcher = MockEventDispatcher()
        outcome = subject.receive_message(message.decode(), message)

        self.assertFalse(outcome)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task due tomorrow is held in the timer, not the ready queue."""
        subject = MyCarrotListener(
            self.ready_queue, self.eta_schedule, self.logger,
            send_events=False,
        )
        dispatcher = MockEventDispatcher()
        subject.event_dispatcher = dispatcher
        backend = MockBackend()
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_message(
            backend, task=foo_task.name, args=[2, 4, 8], kwargs={},
            eta=tomorrow.isoformat(),
        )

        subject.reset_connection()
        # Reset once more with connection retries switched off, restoring
        # the setting afterwards.
        saved_retry = conf.BROKER_CONNECTION_RETRY
        conf.BROKER_CONNECTION_RETRY = False
        try:
            subject.reset_connection()
        finally:
            conf.BROKER_CONNECTION_RETRY = saved_retry
        subject.stop_consumers()
        self.assertTrue(dispatcher.flushed)

        subject.event_dispatcher = MockEventDispatcher()
        subject.receive_message(message.decode(), message)

        # Each timer entry is a 3-item record whose last item wraps the task.
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        task = in_hold[2].args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """Drive ``start()`` with wait methods that fail after one yield.

        Both scenarios expect ``KeyError`` from ``start()``, the init
        callback flag to be set and exactly one recorded iteration.
        """

        class _QoS(object):
            # Minimal QoS stand-in: update() copies ``next`` onto ``prev``.
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Listener(MyCarrotListener):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Raise instead of connecting once an iteration is recorded.
                if self.iterations >= 1:
                    raise KeyError("foo")

            def _detect_wait_method(self):
                return self.wait_method

        # One-element list so the nested callback can flip the flag.
        called_back = [False]

        def init_callback(listener):
            called_back[0] = True

        def new_listener():
            return _Listener(
                self.ready_queue, self.eta_schedule, self.logger,
                send_events=False, init_callback=init_callback,
            )

        # Scenario 1: the wait method yields once, then raises KeyError.
        subject = new_listener()
        subject.qos = _QoS()

        def raises_KeyError(limit=None):
            yield True
            subject.iterations = 1
            raise KeyError("foo")

        subject.wait_method = raises_KeyError
        self.assertRaises(KeyError, subject.start)
        self.assertTrue(called_back[0])
        self.assertEqual(subject.iterations, 1)
        self.assertEqual(subject.qos.prev, subject.qos.next)

        # Scenario 2: the wait method raises socket.error instead.
        subject = new_listener()
        subject.qos = _QoS()

        def raises_socket_error(limit=None):
            yield True
            subject.iterations = 1
            raise socket.error("foo")

        subject.wait_method = raises_socket_error
        # KeyError is still what escapes: presumably start() reacts to the
        # socket error by calling reset_connection() again, which raises
        # because ``iterations`` is already 1 -- TODO confirm in start().
        self.assertRaises(KeyError, subject.start)
        # NOTE(review): the flag is still True from scenario 1, so this
        # assertion cannot fail on its own.
        self.assertTrue(called_back[0])
        self.assertEqual(subject.iterations, 1)
Esempio n. 14
0
class test_Consumer(AppCase):
    def setup(self):
        """Create the buffer queue, timer and a sample task for each test."""
        self.buffer = FastQueue()
        self.timer = Timer()

        # Registered on this test's app with shared=False; multiplies its
        # three arguments so results are easy to assert on.
        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z

        self.foo_task = foo_task

    def teardown(self):
        """Stop the timer created in ``setup``."""
        self.timer.stop()

    def test_info(self):
        """Controller stats expose the prefetch count and broker info."""
        subject = MyKombuConsumer(self.buffer.put, timer=self.timer,
                                  app=self.app)
        subject.task_consumer = Mock()
        subject.qos = QoS(subject.task_consumer.qos, 10)
        subject.connection = Mock()
        subject.connection.info.return_value = {"foo": "bar"}

        # Wire a controller with a mocked pool back to the consumer.
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()
        subject.controller.pool.info.return_value = [Mock(), Mock()]
        subject.controller.consumer = subject

        stats = subject.controller.stats()
        self.assertEqual(stats["prefetch_count"], 10)
        self.assertTrue(stats["broker"])

    def test_start_when_closed(self):
        """``start()`` on a consumer whose blueprint is CLOSE must not raise."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = CLOSE
        l.start()

    def test_connection(self):
        """Connection survives a restart and is cleared by ``shutdown()``."""
        subject = MyKombuConsumer(self.buffer.put, timer=self.timer,
                                  app=self.app)
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()

        # Starting the blueprint establishes a connection.
        subject.blueprint.start(subject)
        self.assertIsInstance(subject.connection, Connection)

        # A restart leaves a connection in place.
        subject.blueprint.state = RUN
        subject.event_dispatcher = None
        subject.blueprint.restart(subject)
        self.assertTrue(subject.connection)

        # Shutdown drops both the connection and the task consumer.
        subject.blueprint.state = RUN
        subject.shutdown()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

        # The same holds after another start / restart / stop cycle.
        subject.blueprint.start(subject)
        self.assertIsInstance(subject.connection, Connection)
        subject.blueprint.restart(subject)

        subject.stop()
        subject.shutdown()
        for attr in ("connection", "task_consumer"):
            self.assertIsNone(getattr(subject, attr))

    def test_close_connection(self):
        """Shutting down individual steps closes what they own."""
        # The Connection step closes the connection and clears it.
        subject = MyKombuConsumer(self.buffer.put, timer=self.timer,
                                  app=self.app)
        subject.blueprint.state = RUN
        connection_step = find_step(subject, consumer.Connection)
        conn = Mock()
        subject.connection = conn
        connection_step.shutdown(subject)
        self.assertTrue(conn.close.called)
        self.assertIsNone(subject.connection)

        # The Events and Heart steps close the dispatcher and the heart.
        subject = MyKombuConsumer(self.buffer.put, timer=self.timer,
                                  app=self.app)
        eventer = mock_event_dispatcher()
        subject.event_dispatcher = eventer
        eventer.enabled = True
        heart = MockHeart()
        subject.heart = heart
        subject.blueprint.state = RUN
        events_step = find_step(subject, consumer.Events)
        events_step.shutdown(subject)
        heart_step = find_step(subject, consumer.Heart)
        heart_step.shutdown(subject)
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch("celery.worker.consumer.warn")
    def test_receive_message_unknown(self, warn):
        """An unrecognised message triggers the patched ``warn``."""
        subject = _MyKombuConsumer(self.buffer.put, timer=self.timer,
                                   app=self.app)
        subject.blueprint.state = RUN
        subject.steps.pop()
        message = create_message(Mock(), unknown={"baz": "!!!"})
        subject.event_dispatcher = mock_event_dispatcher()
        subject.node = MockNode()

        on_message = self._get_on_message(subject)
        on_message(message)

        self.assertTrue(warn.call_count)

    @patch("celery.worker.strategy.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """A message is acknowledged even if the ETA conversion overflows."""
        to_timestamp.side_effect = OverflowError()
        subject = _MyKombuConsumer(self.buffer.put, timer=self.timer,
                                   app=self.app)
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()
        subject.blueprint.state = RUN
        subject.steps.pop()
        # NOTE(review): ("2, 2") has no trailing comma, so ``args`` is the
        # string "2, 2" rather than a tuple.
        message = create_task_message(
            Mock(),
            self.foo_task.name,
            args=("2, 2"),
            kwargs={},
            eta=datetime.now().isoformat(),
        )
        subject.event_dispatcher = mock_event_dispatcher()
        subject.node = MockNode()
        subject.update_strategies()
        subject.qos = Mock()

        on_message = self._get_on_message(subject)
        on_message(message)

        self.assertTrue(message.acknowledged)

    @patch("celery.worker.consumer.error")
    def test_receive_message_InvalidTaskError(self, error):
        """A strategy raising ``InvalidTaskError`` is reported via ``error``."""
        subject = _MyKombuConsumer(self.buffer.put, timer=self.timer,
                                   app=self.app)
        subject.blueprint.state = RUN
        subject.event_dispatcher = mock_event_dispatcher()
        subject.steps.pop()
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()
        message = create_task_message(
            Mock(),
            self.foo_task.name,
            args=(1, 2),
            kwargs="foobarbaz",
            id=1,
        )
        subject.update_strategies()
        # NOTE(review): the dispatcher is assigned a second time here; the
        # repeat looks redundant but is kept as-is.
        subject.event_dispatcher = mock_event_dispatcher()
        # Replace the task's strategy with one that always fails.
        failing_strategy = Mock(name="strategy")
        subject.strategies[self.foo_task.name] = failing_strategy
        failing_strategy.side_effect = InvalidTaskError()

        on_message = self._get_on_message(subject)
        on_message(message)

        self.assertTrue(error.called)
        logged_format = error.call_args[0][0]
        self.assertIn("Received invalid task message", logged_format)

    @patch("celery.worker.consumer.crit")
    def test_on_decode_error(self, crit):
        """``on_decode_error`` acks the message and reports via ``crit``."""
        subject = Consumer(self.buffer.put, timer=self.timer, app=self.app)

        class MockMessage(Mock):
            # Fixed payload attributes; everything else is auto-mocked.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        undecodable = MockMessage()
        subject.on_decode_error(undecodable, KeyError("foo"))

        self.assertTrue(undecodable.ack.call_count)
        logged_format = crit.call_args[0][0]
        self.assertIn("Can't decode message body", logged_format)

    def _get_on_message(self, l):
        """Run one loop pass on *l* and return its ``on_message`` callback.

        The consumer gets mocked collaborators and a connection whose
        ``drain_events`` raises ``WorkerShutdown``, so the loop exits right
        away; the callback is then read off the mocked task consumer.
        """
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = mock_event_dispatcher()
        l.task_consumer = Mock()
        connection = Mock()
        connection.drain_events.side_effect = WorkerShutdown()
        l.connection = connection

        with self.assertRaises(WorkerShutdown):
            l.loop(*l.loop_args())

        on_message = l.task_consumer.on_message
        self.assertTrue(on_message)
        return on_message

    def test_receieve_message(self):
        """A plain task message lands in the buffer as a ``Request``."""
        subject = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()
        subject.blueprint.state = RUN
        subject.event_dispatcher = mock_event_dispatcher()
        message = create_task_message(
            Mock(), self.foo_task.name, args=[2, 4, 8], kwargs={},
        )
        subject.update_strategies()

        on_message = self._get_on_message(subject)
        on_message(message)

        request = self.buffer.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.name, self.foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        # Nothing was given an ETA, so the timer must still be empty.
        self.assertTrue(self.timer.empty())

    def test_start_channel_error(self):
        """An error listed in ``channel_errors`` escapes from ``start()``."""

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                # First call raises KeyError, every later call SyntaxError.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        subject = MockConsumer(
            self.buffer.put,
            timer=self.timer,
            send_events=False,
            pool=BasePool(),
            app=self.app,
        )
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()
        subject.channel_errors = (KeyError,)

        self.assertRaises(KeyError, subject.start)
        subject.timer.stop()

    def test_start_connection_error(self):
        """An error listed in ``connection_errors`` does not end ``start()``.

        The first ``loop`` call raises ``KeyError`` (declared a connection
        error); the ``SyntaxError`` from the second call is what escapes.
        """

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                # First call raises KeyError, every later call SyntaxError.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        subject = MockConsumer(
            self.buffer.put,
            timer=self.timer,
            send_events=False,
            pool=BasePool(),
            app=self.app,
        )
        subject.controller = subject.app.WorkController()
        subject.pool = subject.controller.pool = Mock()

        subject.connection_errors = (KeyError,)
        with self.assertRaises(SyntaxError):
            subject.start()
        subject.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        """loop() completes when drain_events raises socket.timeout."""
        connection_cls = self.app.connection().__class__

        class Connection(connection_cls):
            obj = None

            def drain_events(self, **kwargs):
                # Drop the owning consumer's connection, then time out.
                owner = self.obj
                owner.connection = None
                raise socket.timeout(10)

        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        fake_connection = Connection()
        worker.connection = fake_connection
        worker.task_consumer = Mock()
        fake_connection.obj = worker
        worker.qos = QoS(worker.task_consumer.qos, 10)
        loop_args = worker.loop_args()
        worker.loop(*loop_args)

    def test_loop_when_socket_error(self):
        """loop() raises socket.error in state RUN but completes in state CLOSE."""
        connection_cls = self.app.connection().__class__

        class Connection(connection_cls):
            obj = None

            def drain_events(self, **kwargs):
                # Drop the owning consumer's connection, then fail.
                owner = self.obj
                owner.connection = None
                raise socket.error("foo")

        worker = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.blueprint.state = RUN
        fake_connection = Connection()
        worker.connection = fake_connection
        fake_connection.obj = worker
        worker.task_consumer = Mock()
        worker.qos = QoS(worker.task_consumer.qos, 10)
        self.assertRaises(socket.error, lambda: worker.loop(*worker.loop_args()))

        # drain_events cleared the connection above; put it back and retry
        # with the blueprint marked as closing.
        worker.blueprint.state = CLOSE
        worker.connection = fake_connection
        worker.loop(*worker.loop_args())

    def test_loop(self):
        """loop() starts consuming and applies the prefetch count via qos."""
        connection_cls = self.app.connection().__class__

        class Connection(connection_cls):
            obj = None

            def drain_events(self, **kwargs):
                # Clearing the owner's connection is the only effect.
                owner = self.obj
                owner.connection = None

        worker = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.blueprint.state = RUN
        fake_connection = Connection()
        worker.connection = fake_connection
        fake_connection.obj = worker
        task_consumer = Mock()
        worker.task_consumer = task_consumer
        worker.qos = QoS(task_consumer.qos, 10)

        for _ in range(2):
            worker.loop(*worker.loop_args())
        self.assertTrue(task_consumer.consume.call_count)
        task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(worker.qos.value, 10)

        # A deferred decrement changes the value at once; update() then
        # pushes the new prefetch count to the task consumer.
        worker.qos.decrement_eventually()
        self.assertEqual(worker.qos.value, 9)
        worker.qos.update()
        self.assertEqual(worker.qos.value, 9)
        task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        """ignore_errors swallows connection/channel errors and nothing else."""
        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.connection_errors = (AttributeError, KeyError)
        worker.channel_errors = (SyntaxError,)

        # Every configured error type is suppressed.
        for swallowed in (AttributeError, KeyError, SyntaxError):
            ignore_errors(worker, Mock(side_effect=swallowed("foo")))

        # An error type that is not configured propagates.
        failing = Mock(side_effect=IndexError("foo"))
        with self.assertRaises(IndexError):
            ignore_errors(worker, failing)

    def test_apply_eta_task(self):
        """apply_eta_task reserves the task, lowers qos by one and buffers it."""
        from celery.worker import state

        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.controller = worker.app.WorkController()
        pool = Mock()
        worker.pool = pool
        worker.controller.pool = pool
        worker.qos = QoS(None, 10)

        sentinel_task = object()
        prefetch_before = worker.qos.value
        worker.apply_eta_task(sentinel_task)

        self.assertIn(sentinel_task, state.reserved_requests)
        self.assertEqual(worker.qos.value, prefetch_before - 1)
        self.assertIs(self.buffer.get_nowait(), sentinel_task)

    def test_receieve_message_eta_isoformat(self):
        """A message with an ISO-format eta lands on the timer and raises qos."""
        worker = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.controller = worker.app.WorkController()
        pool = Mock()
        worker.pool = pool
        worker.controller.pool = pool
        worker.blueprint.state = RUN
        worker.steps.pop()
        tomorrow = datetime.now() + timedelta(days=1)
        message = create_task_message(
            Mock(), self.foo_task.name, eta=tomorrow.isoformat(), args=[2, 4, 8], kwargs={}
        )

        worker.task_consumer = Mock()
        worker.qos = QoS(worker.task_consumer.qos, 1)
        prefetch_before = worker.qos.value
        worker.event_dispatcher = mock_event_dispatcher()
        worker.enabled = False
        worker.update_strategies()
        on_message = self._get_on_message(worker)
        on_message(message)
        worker.timer.stop()
        worker.timer.join(1)

        # Third field of each timer queue entry carries the scheduled call;
        # its first argument is the request.
        scheduled_names = [entry[2].args[0].name for entry in self.timer.queue]
        self.assertTrue(self.foo_task.name in scheduled_names)
        self.assertGreater(worker.qos.value, prefetch_before)
        worker.timer.stop()

    def test_pidbox_callback(self):
        """Control box on_message forwards to the node and survives node errors."""
        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        box = find_step(worker, consumer.Control).box
        box.node = Mock()
        box.reset = Mock()

        # Plain delivery.
        box.on_message("foo", "bar")
        box.node.handle_message.assert_called_with("foo", "bar")

        # Neither a KeyError nor a ValueError from the node may escape.
        for failure in (KeyError("foo"), ValueError("foo")):
            node = Mock()
            node.handle_message.side_effect = failure
            box.node = node
            box.on_message("foo", "bar")
            node.handle_message.assert_called_with("foo", "bar")

        # By the end of the sequence the box must have been reset.
        self.assertTrue(box.reset.called)

    def test_revoke(self):
        """A message whose id was added to the revoked set is not buffered."""
        from celery.worker.state import revoked

        worker = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.blueprint.state = RUN
        worker.steps.pop()
        task_id = uuid()
        message = create_task_message(
            Mock(), self.foo_task.name, args=[2, 4, 8], kwargs={}, id=task_id
        )

        revoked.add(task_id)

        on_message = self._get_on_message(worker)
        on_message(message)
        self.assertTrue(self.buffer.empty())

    def test_receieve_message_not_registered(self):
        """A message for task name "x.X.31x" is neither buffered nor scheduled."""
        worker = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.blueprint.state = RUN
        worker.steps.pop()
        message = create_task_message(
            Mock(name="channel"), "x.X.31x", args=[2, 4, 8], kwargs={}
        )

        worker.event_dispatcher = mock_event_dispatcher()
        on_message = self._get_on_message(worker)
        outcome = on_message(message)
        self.assertFalse(outcome)
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())

    @patch("celery.worker.consumer.warn")
    @patch("celery.worker.consumer.logger")
    def test_receieve_message_ack_raises(self, logger, warn):
        """A message with ``headers = None`` is warned about and rejected.

        ``logger`` and ``warn`` are the patched ``celery.worker.consumer``
        module attributes.
        """
        worker = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        worker.controller = worker.app.WorkController()
        pool = Mock()
        worker.pool = pool
        worker.controller.pool = pool
        worker.blueprint.state = RUN
        message = create_task_message(
            Mock(), self.foo_task.name, args=[2, 4, 8], kwargs={}
        )
        message.headers = None

        worker.event_dispatcher = mock_event_dispatcher()
        worker.update_strategies()
        worker.connection_errors = (socket.error,)
        message.reject = Mock(side_effect=socket.error("foo"))
        on_message = self._get_on_message(worker)
        outcome = on_message(message)

        self.assertFalse(outcome)
        self.assertTrue(warn.call_count)
        self.assertRaises(Empty, self.buffer.get_nowait)
        self.assertTrue(self.timer.empty())
        message.reject_log_error.assert_called_with(logger, worker.connection_errors)

    def test_receive_message_eta(self):
        """A message with an eta one day ahead is held on the timer, not buffered.

        Setting the ``C_DEBUG_TEST`` environment variable makes every step
        print a progress marker to ``sys.__stderr__``.
        """
        import sys
        from functools import partial

        # pp is the progress tracer: a real print to the original stderr when
        # C_DEBUG_TEST is set, otherwise a no-op accepting any arguments.
        if os.environ.get("C_DEBUG_TEST"):
            pp = partial(print, file=sys.__stderr__)
        else:

            def pp(*args, **kwargs):
                pass

        pp("TEST RECEIVE MESSAGE ETA")
        pp("+CREATE MYKOMBUCONSUMER")
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.controller = l.app.WorkController()
        l.pool = l.controller.pool = Mock()
        pp("-CREATE MYKOMBUCONSUMER")
        l.steps.pop()
        l.event_dispatcher = mock_event_dispatcher()
        channel = Mock(name="channel")
        pp("+ CREATE MESSAGE")
        # eta is passed as an ISO-format string, one day from now.
        m = create_task_message(
            channel, self.foo_task.name, args=[2, 4, 8], kwargs={}, eta=(datetime.now() + timedelta(days=1)).isoformat()
        )
        pp("- CREATE MESSAGE")

        try:
            pp("+ BLUEPRINT START 1")
            l.blueprint.start(l)
            pp("- BLUEPRINT START 1")
            # Start a second time with BROKER_CONNECTION_RETRY switched off,
            # then put the previous value back.
            # NOTE(review): the value is not restored if this second start
            # raises, because the restore is not in a finally clause.
            p = l.app.conf.BROKER_CONNECTION_RETRY
            l.app.conf.BROKER_CONNECTION_RETRY = False
            pp("+ BLUEPRINT START 2")
            l.blueprint.start(l)
            pp("- BLUEPRINT START 2")
            l.app.conf.BROKER_CONNECTION_RETRY = p
            pp("+ BLUEPRINT RESTART")
            l.blueprint.restart(l)
            pp("- BLUEPRINT RESTART")
            l.event_dispatcher = mock_event_dispatcher()
            pp("+ GET ON MESSAGE")
            callback = self._get_on_message(l)
            pp("- GET ON MESSAGE")
            pp("+ CALLBACK")
            callback(m)
            pp("- CALLBACK")
        finally:
            # Always stop the timer, even when a step above failed.
            pp("+ STOP TIMER")
            l.timer.stop()
            pp("- STOP TIMER")
            try:
                pp("+ JOIN TIMER")
                l.timer.join()
                pp("- JOIN TIMER")
            except RuntimeError:
                # A RuntimeError from join() is deliberately ignored.
                pass

        # The first timer entry is an (eta, priority, entry) triple whose
        # entry carries the Request as its first argument.
        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, self.foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        # Nothing may have reached the ready buffer.
        with self.assertRaises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        """Control box reset() closes the node channel even if close() fails."""
        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        box = find_step(worker, consumer.Control).box
        box.node = Mock()
        channel = Mock()
        box.node.channel = channel
        worker.connection = Mock()
        # close() raises an error type registered as a connection error.
        channel.close.side_effect = socket.error("foo")
        worker.connection_errors = (socket.error,)

        box.reset()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool the Control step uses gPidbox and spawns its loop."""
        from celery.worker.pidbox import gPidbox

        green_pool = Mock()
        green_pool.is_green = True
        worker = MyKombuConsumer(
            self.buffer.put, timer=self.timer, pool=green_pool, app=self.app
        )
        control = find_step(worker, consumer.Control)
        self.assertIsInstance(control.box, gPidbox)

        control.start(worker)
        worker.pool.spawn_n.assert_called_with(control.box.loop, worker)

    def test__green_pidbox_node(self):
        """Run the Control box loop against a fake connection with a green pool.

        Checks that the loop consumes from the node listener, clears the
        consumer's connection and closes the connection it opened.
        """
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool, app=self.app)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        # Mock usable as a context manager: consume() on enter, cancel() on exit.
        class BConsumer(Mock):
            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        # Every fake Connection created below registers itself here.
        connections = []

        class Connection(object):
            # Class-level default; the first drain_events() call shadows it
            # with an instance attribute through ``self.calls += 1``.
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return "dummy://"

            def drain_events(self, **kwargs):
                # First call times out; the next call clears the owner's
                # connection and sets the box's _node_shutdown event.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        # Make connect() hand out the fake connection defined above.
        l.connect = lambda: Connection(obj=l)
        # NOTE(review): same lookup as at the top of the test; presumably it
        # returns the same step object - confirm before removing.
        controller = find_step(l, consumer.Control)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch("kombu.connection.Connection._establish_connection")
    @patch("kombu.utils.sleep")
    def test_connect_errback(self, sleep, connect):
        """connect() completes although the first connection attempt fails.

        ``connect`` is the patched ``Connection._establish_connection`` and
        ``sleep`` the patched ``kombu.utils.sleep``.
        """
        from kombu.transport.memory import Transport

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

        def effect():
            # Raise until the mock has been called more than once, then
            # return normally.
            if connect.call_count > 1:
                return
            raise ChannelError("error")

        connect.side_effect = effect
        # Transport is a class shared by the whole process, so the override
        # is scoped with patch.object and undone when the block exits, even
        # if an assertion fails.
        with patch.object(Transport, "connection_errors", (ChannelError,)):
            l.connect()
            connect.assert_called_with()

    def test_stop_pidbox_node(self):
        """Control.stop() completes when _node_stopped is already set."""
        worker = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        control = find_step(worker, consumer.Control)
        stopped = Event()
        control._node_stopped = stopped
        control._node_shutdown = Event()
        stopped.set()
        control.stop(worker)

    def test_start__loop(self):
        """start() propagates errors raised by loop(), in two scenarios.

        First a hand-written loop that raises KeyError on its second call,
        then a Mock loop that raises socket.error.
        """

        # Minimal QoS stand-in: update() copies value into prev.
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Fails once loop() has run at least once.
                if self.iterations >= 1:
                    raise KeyError("foo")

        init_callback = Mock()
        l = _Consumer(self.buffer.put, timer=self.timer, init_callback=init_callback, app=self.app)
        l.controller = l.app.WorkController()
        l.pool = l.controller.pool = Mock()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        # NOTE(review): Connection is not defined in this method; presumably
        # the module-level import - confirm.
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(*args, **kwargs):
            # Count the call, bring the QoS stub up to date if needed, and
            # raise from the second call onwards.
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.loop = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertEqual(l.iterations, 2)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Second scenario: a fresh consumer whose loop always raises
        # socket.error.
        init_callback.reset_mock()
        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app, send_events=False, init_callback=init_callback)
        l.controller = l.app.WorkController()
        l.pool = l.controller.pool = Mock()
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.loop = Mock(side_effect=socket.error("foo"))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(l.loop.call_count)

    def test_reset_connection_with_no_node(self):
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.controller = l.app.WorkController()
        l.pool = l.controller.pool = Mock()
        l.steps.pop()
        l.blueprint.start(l)