Пример #1
0
class test_Consumer(Case):
    """Worker consumer tests driven through the ``MyKombuConsumer`` helper.

    ``setUp`` gives every test a fresh ready queue (``FastQueue``) and
    ``Timer``; ``tearDown`` stops that timer again.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.timer = Timer()

    def tearDown(self):
        self.timer.stop()

    def _make_flaky_consumer(self):
        """Return a BlockingConsumer whose consume_messages() always fails.

        The first call raises ``KeyError('foo')``, every later call raises
        ``SyntaxError('bar')``.  Callers register ``KeyError`` as a
        recoverable error and expect ``start()`` to end with the
        ``SyntaxError``.
        """
        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        return MockConsumer(self.ready_queue,
                            timer=self.timer,
                            send_events=False,
                            pool=BasePool())

    def test_info(self):
        # ``info`` must report the QoS prefetch count, and only report broker
        # details once a connection has been attached.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.qos = QoS(l.task_consumer, 10)
        info = l.info
        self.assertEqual(info['prefetch_count'], 10)
        self.assertFalse(info['broker'])

        l.connection = current_app.connection()
        info = l.info
        self.assertTrue(info['broker'])

    def test_start_when_closed(self):
        # Smoke test: start() on a consumer already in the CLOSE state must
        # not raise.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = CLOSE
        l.start()

    def test_connection(self):
        # reset_connection() must attach a Connection.  stop_consumers()
        # keeps it when close_connection=False, otherwise it drops both the
        # connection and the task consumer, as does stop() +
        # close_connection().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)

        l.reset_connection()
        self.assertIsInstance(l.connection, Connection)

        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close_connection=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, Connection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() in the RUN state must not raise, and
        # stop_consumers() must close the event dispatcher and the heart.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = RUN
        l.close_connection()

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        eventer = l.event_dispatcher = Mock()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        # A message with an unrecognised payload must be reported via warn().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, unknown={'baz': '!!!'})
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.consumer.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # If converting the ETA raises OverflowError the message must still
        # be acknowledged.
        to_timestamp.side_effect = OverflowError()
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=(2, 2),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()
        l.update_strategies()
        l.qos = Mock()

        l.receive_message(m.decode(), m)
        self.assertTrue(to_timestamp.called)
        self.assertTrue(m.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        # A task message whose kwargs is a string must be logged as an
        # invalid task message.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=(1, 2),
                           kwargs='foobarbaz',
                           id=1)
        l.update_strategies()
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn('Received invalid task message', error.call_args[0][0])

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        # An undecodable message must be acked and reported via crit().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        l.on_decode_error(message, KeyError('foo'))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def test_receieve_message(self):
        # A plain task message must land on the ready queue as a Request that
        # executes foo_task with the given args; nothing goes on the timer.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={})
        l.update_strategies()

        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_connection_error(self):
        # start() must survive an error listed in connection_errors (the
        # first KeyError) and propagate any other error (the SyntaxError).
        l = self._make_flaky_consumer()
        l.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            l.start()
        l.heart.stop()
        l.timer.stop()

    def test_start_channel_error(self):
        # Regression test for AMQPChannelExceptions that can occur within the
        # consumer. (i.e. 404 errors)
        # Here the KeyError is listed in channel_errors instead of
        # connection_errors; start() must still end with the SyntaxError.
        l = self._make_flaky_consumer()
        l.channel_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            l.start()
        l.heart.stop()
        l.timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        # A socket.timeout raised by drain_events() must not propagate out of
        # consume_messages().
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer, 10)
        l.consume_messages()

    def test_consume_messages_when_socket_error(self):
        # A socket.error from drain_events() must propagate in the RUN state
        # but not once the consumer is in the CLOSE state.
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10)
        with self.assertRaises(socket.error):
            l.consume_messages()

        l._state = CLOSE
        l.connection = c
        l.consume_messages()

    def test_consume_messages(self):
        # consume_messages() must start consuming and apply the QoS prefetch
        # count; later QoS changes must reach the task consumer via update().
        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        l.task_consumer.qos = Mock()
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        # maybe_conn_error() must swallow AttributeError and the configured
        # connection/channel errors, but re-raise anything else.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(Mock(side_effect=AttributeError('foo')))
        l.maybe_conn_error(Mock(side_effect=KeyError('foo')))
        l.maybe_conn_error(Mock(side_effect=SyntaxError('foo')))
        with self.assertRaises(IndexError):
            l.maybe_conn_error(Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        # apply_eta_task() must mark the request as reserved, decrement the
        # QoS prefetch value by one and put the request on the ready queue.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.qos = QoS(None, 10)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        try:
            self.assertIn(task, state.reserved_requests)
            self.assertEqual(l.qos.value, qos - 1)
            self.assertIs(self.ready_queue.get_nowait(), task)
        finally:
            # Don't leak the dummy request into the global worker state that
            # later tests observe.
            state.reserved_requests.discard(task)

    def test_receieve_message_eta_isoformat(self):
        # A task with an ETA must be scheduled on the timer and must raise
        # the QoS prefetch value above its starting value.
        if sys.version_info < (2, 6):
            raise SkipTest('test broken on Python 2.5')
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(),
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        l.receive_message(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        # Timer queue entries are (eta, priority, entry) tuples; the
        # scheduled request is the entry's first positional argument.
        scheduled = [entry[2] for entry in self.timer.queue]
        self.assertTrue(any(item.args[0].name == foo_task.name
                            for item in scheduled))
        self.assertGreater(l.qos.value, current_pcount)

    def test_on_control(self):
        # on_control() must forward to pidbox_node.handle_message(); errors
        # raised by the handler must not propagate, and after a ValueError
        # reset_pidbox_node() must have been called.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()
        l.reset_pidbox_node = Mock()

        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = KeyError('foo')
        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = ValueError('foo')
        l.on_control('foo', 'bar')
        l.pidbox_node.handle_message.assert_called_with('foo', 'bar')
        l.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        # A task whose id is in the revoked set must not reach the ready
        # queue.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, timer=self.timer)
        backend = Mock()
        task_id = uuid()
        t = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # For an unregistered task name receive_message() must return a false
        # value and queue nothing on the ready queue or the timer.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        self.assertFalse(l.receive_message(m.decode(), m))
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        # If reject() raises a connection error, receive_message() must still
        # return a false value, warn and log critically, queueing nothing.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.connection_errors = (socket.error, )
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())
        m.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receieve_message_eta(self):
        # A task with an ETA one day ahead must be held on the timer as an
        # (eta, priority, entry) tuple rather than put on the ready queue.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend,
            task=foo_task.name,
            args=[2, 4, 8],
            kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )

        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.timer.stop()
        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()

    def test_reset_pidbox_node(self):
        # A connection error while closing the pidbox channel must not
        # propagate out of reset_pidbox_node().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()
        chan = l.pidbox_node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error('foo')
        l.connection_errors = (socket.error, )
        l.reset_pidbox_node()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a green pool, reset_pidbox_node() must spawn
        # _green_pidbox_node via pool.spawn_n().
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pool = Mock()
        l.pool.is_green = True
        l.reset_pidbox_node()
        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)

    def test__green_pidbox_node(self):
        # _green_pidbox_node() must listen with on_control as callback, start
        # the broadcast consumer, keep draining after a socket.timeout, and
        # close its own connection when done.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.pidbox_node = Mock()

        class BConsumer(Mock):
            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        l.pidbox_node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out, second call signals shutdown.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                self.obj._pidbox_node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l._open_connection = lambda: Connection(obj=l)
        l._green_pidbox_node()

        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
        self.assertTrue(l.broadcast_consumer)
        l.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_open_connection_errback(self, sleep, connect):
        # The first connect attempt raises StdChannelError (registered as a
        # connection error of the memory transport) and later attempts
        # succeed; _open_connection() is expected to complete.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        from kombu.transport.memory import Transport
        prev_connection_errors = Transport.connection_errors
        Transport.connection_errors = (StdChannelError, )

        def effect():
            if connect.call_count > 1:
                return
            raise StdChannelError()

        connect.side_effect = effect
        try:
            l._open_connection()
        finally:
            # Restore the class attribute so other tests see the transport's
            # real connection errors.
            Transport.connection_errors = prev_connection_errors
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        # Smoke test: stop_pidbox_node() must return when the node is already
        # flagged as stopped.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l._pidbox_node_stopped = Event()
        l._pidbox_node_shutdown = Event()
        l._pidbox_node_stopped.set()
        l.stop_pidbox_node()

    def test_start__consume_messages(self):
        # start() must call init_callback and run consume_messages() (which
        # syncs QoS) once before the KeyError from the second
        # reset_connection() propagates.  A socket.error raised by
        # consume_messages() must propagate as well.
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.ready_queue,
                      timer=self.timer,
                      init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.consume_messages = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.ready_queue,
                      timer=self.timer,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.consume_messages = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(init_callback.call_count)
        self.assertTrue(l.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        # With no pool set, reset_connection() must not raise.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        self.assertEqual(None, l.pool)
        l.reset_connection()

    def test_on_task_revoked(self):
        # Smoke test: on_task() with a revoked request must not raise.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = True
        l.on_task(task)

    def test_on_task_no_events(self):
        # Smoke test: on_task() with events disabled and no ETA must not
        # raise.
        l = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = False
        l.event_dispatcher = Mock()
        l.event_dispatcher.enabled = False
        task.eta = None
        l._does_info = False
        l.on_task(task)
Пример #2
0
class test_CarrotListener(unittest.TestCase):
    """Tests for ``CarrotListener``.

    ``setUp`` gives every test a fresh ready queue (``FastQueue``), ETA
    schedule (``Timer``) and a logger whose level is set to 0.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = get_logger()
        self.logger.setLevel(0)

    def test_mainloop(self):
        # _mainloop() must yield the result of drain_events().
        # _detect_wait_method() must pick _mainloop (registering the broadcast
        # callback and starting both consumers) while the low-level connection
        # has drain_events().  It must fall back to task_consumer.iterconsume
        # once that connection is swapped for a PlaceHolder (presumably
        # lacking drain_events).
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)

        class MockConnection(object):

            def drain_events(self):
                return "draining"

        l.connection = MockConnection()
        l.connection.connection = MockConnection()

        it = l._mainloop()
        # assertTrue's second argument is only the failure message, so the
        # yielded value has to be compared explicitly.
        self.assertEqual(it.next(), "draining")

        records = {}
        def create_recorder(key):
            def _recorder(*args, **kwargs):
                records[key] = True
            return _recorder

        l.task_consumer = PlaceHolder()
        l.task_consumer.iterconsume = create_recorder("consume_tasks")
        l.broadcast_consumer = PlaceHolder()
        l.broadcast_consumer.register_callback = create_recorder(
                                                    "broadcast_callback")
        l.broadcast_consumer.iterconsume = create_recorder(
                                             "consume_broadcast")
        l.task_consumer.add_consumer = create_recorder("consumer_add")

        records.clear()
        self.assertEqual(l._detect_wait_method(), l._mainloop)
        for record in ("broadcast_callback", "consume_broadcast",
                "consume_tasks"):
            self.assertTrue(records.get(record))

        records.clear()
        l.connection.connection = PlaceHolder()
        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
        self.assertTrue(records.get("consumer_add"))

    def test_connection(self):
        # reset_connection() must attach a BrokerConnection.  Both
        # stop_consumers() and stop() + close_connection() must drop the
        # connection and the task consumer.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_receive_message_control_command(self):
        # A control message must be dispatched to control_dispatch under its
        # command name.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, control={"command": "shutdown"})
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()
        l.receive_message(m.decode(), m)
        self.assertIn("shutdown", l.control_dispatch.commands)

    def test_close_connection(self):
        # close_connection() in the RUN state must not raise, and
        # stop_consumers() must close the event dispatcher and the heart.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        l._state = RUN
        l.close_connection()

        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A message with an unrecognised payload must emit an
        # "unknown message" warning.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_InvalidTaskError(self):
        # A task message whose kwargs is a string must be logged as
        # "Invalid task ignored".
        logger = MockLogger()
        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        # An undecodable message must be acked and logged as a decoding
        # error.
        logger = MockLogger()
        l = CarrotListener(self.ready_queue, self.eta_schedule, logger,
                           send_events=False)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Message decoding error", logger.logged[0])

    def test_receieve_message(self):
        # A plain task message must land on the ready queue as a TaskRequest
        # that executes foo_task; nothing goes on the ETA schedule.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        # A task with an ETA must be put on the ETA schedule, and the task
        # consumer's qos() must be invoked.

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.receive_message(m.decode(), m)

        # Schedule entries are (eta, priority, entry) tuples; the request is
        # the entry's first positional argument.
        items = [entry[2] for entry in self.eta_schedule.queue]
        found = False
        for item in items:
            if item.args[0].task_name == foo_task.name:
                found = True
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        # A revoke control command must add the task id to the revoked set,
        # after which the matching task message must not reach the ready
        # queue.
        ready_queue = FastQueue()
        l = CarrotListener(ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()
        c = create_message(backend, control={"command": "revoke",
                                             "task_id": task_id})
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(c.decode(), c)
        from celery.worker.state import revoked
        self.assertIn(task_id, revoked)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # For an unregistered task name receive_message() must return a false
        # value and queue nothing on the ready queue or the ETA schedule.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # A task with an ETA one day ahead must be held on the ETA schedule as
        # an (eta, priority, entry) tuple, not put on the ready queue.
        l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                               timedelta(days=1)).isoformat())

        l.reset_connection()
        p, conf.BROKER_CONNECTION_RETRY = conf.BROKER_CONNECTION_RETRY, False
        try:
            l.reset_connection()
        finally:
            conf.BROKER_CONNECTION_RETRY = p
        l.receive_message(m.decode(), m)

        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # start() must call init_callback and sync QoS, then propagate the
        # KeyError raised when reset_connection() is re-entered.  In the
        # second run the wait method raises socket.error; presumably that is
        # handled as a connection error, so the restart again ends in the
        # KeyError from reset_connection().

        class _QoS(object):
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Listener(CarrotListener):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

            def _detect_wait_method(self):
                return self.wait_method

        called_back = [False]
        def init_callback(listener):
            called_back[0] = True


        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()

        def raises_KeyError(limit=None):
            yield True
            l.iterations = 1
            raise KeyError("foo")

        l.wait_method = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.next)

        # Reset the flag; otherwise the assertion below would pass even if
        # the second start() never invoked init_callback.
        called_back[0] = False
        l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        def raises_socket_error(limit=None):
            yield True
            l.iterations = 1
            raise socket.error("foo")

        l.wait_method = raises_socket_error
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
Пример #3
0
class test_Consumer(unittest.TestCase):
    """Tests for the Kombu-based worker consumer (``MyKombuConsumer``).

    Each test gets a fresh ready queue, a ``Timer`` used as the ETA
    schedule and the app's default logger; the schedule is stopped again
    in :meth:`tearDown`.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = app_or_default().log.get_default_logger()
        # 0 == logging.NOTSET (the effective level then comes from the
        # logger's ancestors).
        self.logger.setLevel(0)

    def tearDown(self):
        self.eta_schedule.stop()

    def test_connection(self):
        # reset_connection() must open a BrokerConnection; stop_consumers()
        # and close_connection() must clear both the connection and the
        # task consumer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() must not fail on a consumer that never
        # connected; stop_consumers() must close the event dispatcher and
        # the heart.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = RUN
        l.close_connection()

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A message without a task name must produce an "unknown message"
        # warning.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        # If converting the ETA to a timestamp raises OverflowError, the
        # message must still be acknowledged.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            called[0] = True
            raise OverflowError()

        # NOTE(review): ("2, 2") is a plain string, not a tuple; presumably
        # irrelevant here since only the ETA conversion is exercised.
        m = create_message(backend, task=foo_task.name,
                                    args=("2, 2"),
                                    kwargs={},
                                    eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        # Monkeypatch the module-level helper and always restore it.
        from celery.worker import consumer
        prev, consumer.to_timestamp = consumer.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            consumer.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        # A message whose kwargs is not a mapping must be logged as an
        # invalid task.
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
            args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        # A message that cannot be decoded must be acked and logged.
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Message decoding error", logger.logged[0])

    def test_receieve_message(self):
        # A task message without an ETA must go straight to the ready queue
        # and leave the ETA schedule empty.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-format ETA must be put on the ETA schedule
        # and must call task_consumer.qos() through the QoS object.

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Each schedule entry is a tuple whose third item holds the request
        # as its first argument.
        found = any(entry[2].args[0].task_name == foo_task.name
                    for entry in self.eta_schedule.queue)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        # A message whose id has been revoked must not reach the ready
        # queue.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for an unregistered task must return a falsy value and
        # be neither queued nor scheduled.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # A task with an ETA one day ahead must be held on the ETA schedule
        # rather than the ready queue. Also checks that reset_connection()
        # works with BROKER_CONNECTION_RETRY disabled and that
        # stop_consumers() flushes the event dispatcher.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        dispatcher = l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                               timedelta(days=1)).isoformat())

        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        self.assertTrue(dispatcher.flushed)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # start() must call init_callback, run the main loop (updating QoS)
        # and propagate KeyError / socket.error raised from it.

        class _QoS(object):
            # update() copies `next` into `prev` so the test can tell that
            # an update took place.
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # No-op on the first pass; raises once an iteration has
                # been recorded.
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()

        def raises_KeyError(limit=None):
            yield True
            l.iterations = 1
            raise KeyError("foo")

        l._mainloop = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.next)

        # Reset the flag: otherwise it is still True from the first run and
        # the assertion below could not detect a missing init_callback call.
        called_back[0] = False

        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            yield True
            l.iterations = 1
            raise socket.error("foo")

        l._mainloop = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
# Example #4
# 0
class test_Consumer(Case):
    """Tests for the bootstep/namespace based worker consumer.

    Each test gets a fresh ready queue and ``Timer``; the timer is stopped
    again in :meth:`tearDown`.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.timer = Timer()

    def tearDown(self):
        self.timer.stop()

    def test_info(self):
        # WorkController.stats() must report the consumer's prefetch count
        # and broker information.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        l.connection = Mock()
        l.connection.info.return_value = {'foo': 'bar'}
        l.controller = l.app.WorkController()
        l.controller.pool = Mock()
        l.controller.pool.info.return_value = [Mock(), Mock()]
        l.controller.consumer = l
        info = l.controller.stats()
        self.assertEqual(info['prefetch_count'], 10)
        self.assertTrue(info['broker'])

    def test_start_when_closed(self):
        # start() must return (the test would hang otherwise) when the
        # namespace is already CLOSE.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.namespace.state = CLOSE
        l.start()

    def test_connection(self):
        # Walk the connection lifecycle through namespace start()/restart()
        # and shutdown(); shutdown() must clear the connection and the task
        # consumer.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)

        l.namespace.start(l)
        self.assertIsInstance(l.connection, Connection)

        l.namespace.state = RUN
        l.event_dispatcher = None
        l.namespace.restart(l)
        self.assertTrue(l.connection)

        l.namespace.state = RUN
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.namespace.start(l)
        self.assertIsInstance(l.connection, Connection)
        l.namespace.restart(l)

        l.stop()
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # The Connection step's shutdown() must close and clear the
        # connection; the Events and Heart steps must close their objects.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.namespace.state = RUN
        step = find_step(l, consumer.Connection)
        conn = l.connection = Mock()
        step.shutdown(l)
        self.assertTrue(conn.close.called)
        self.assertIsNone(l.connection)

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        eventer = l.event_dispatcher = Mock()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l.namespace.state = RUN
        Events = find_step(l, consumer.Events)
        Events.shutdown(l)
        Heart = find_step(l, consumer.Heart)
        Heart.shutdown(l)
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        # A message without a task name must be warned about.
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, unknown={'baz': '!!!'})
        l.event_dispatcher = Mock()
        l.node = MockNode()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.consumer.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # An OverflowError while converting the ETA must still ack the
        # message.
        to_timestamp.side_effect = OverflowError()
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        # NOTE(review): ('2, 2') is a plain string, not a tuple; presumably
        # irrelevant since only the ETA path is exercised.
        m = create_message(Mock(), task=foo_task.name,
                           args=('2, 2'),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = Mock()
        l.node = MockNode()
        l.update_strategies()
        l.qos = Mock()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(m.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        # A message whose kwargs is not a mapping must be logged as an
        # invalid task message.
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        m = create_message(Mock(), task=foo_task.name,
                           args=(1, 2), kwargs='foobarbaz', id=1)
        l.update_strategies()
        l.event_dispatcher = Mock()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertIn('Received invalid task message', error.call_args[0][0])

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        # An undecodable message must be acked and reported via crit().
        l = Consumer(self.ready_queue, timer=self.timer)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        l.on_decode_error(message, KeyError('foo'))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def _get_on_message(self, l):
        """Run one consumer loop and return the registered message callback.

        Replaces ``l.event_dispatcher``, ``l.task_consumer`` and
        ``l.connection`` with fresh mocks (and ``l.qos`` when it is None).
        The connection's ``drain_events`` raises ``SystemExit``, which is
        how the loop is stopped. The callback returned is the first
        positional argument passed to ``task_consumer.register_callback``.
        """
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = Mock()
        l.task_consumer = Mock()
        l.connection = Mock()
        l.connection.drain_events.side_effect = SystemExit()

        with self.assertRaises(SystemExit):
            l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.register_callback.called)
        return l.task_consumer.register_callback.call_args[0][0]

    def test_receieve_message(self):
        # A task message without an ETA must land in the ready queue as a
        # Request and leave the timer empty.
        l = Consumer(self.ready_queue, timer=self.timer)
        m = create_message(Mock(), task=foo_task.name,
                           args=[2, 4, 8], kwargs={})
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_channel_error(self):
        # A channel error raised by loop() must propagate out of start().

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.ready_queue, timer=self.timer,
                         send_events=False, pool=BasePool())
        l.channel_errors = (KeyError, )
        with self.assertRaises(KeyError):
            l.start()
        l.timer.stop()

    def test_start_connection_error(self):
        # A connection error raised by loop() must make start() call loop()
        # again; the second call raises SyntaxError, which escapes.

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.ready_queue, timer=self.timer,
                         send_events=False, pool=BasePool())

        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        # socket.timeout from drain_events() must not escape loop().

        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer.qos, 10)
        l.loop(*l.loop_args())

    def test_loop_when_socket_error(self):
        # socket.error from drain_events() must propagate while RUN; once
        # the namespace is CLOSE, loop() must not raise.

        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = Consumer(self.ready_queue, timer=self.timer)
        l.namespace.state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            l.loop(*l.loop_args())

        l.namespace.state = CLOSE
        l.connection = c
        l.loop(*l.loop_args())

    def test_loop(self):
        # loop() must start consuming and apply the QoS prefetch count;
        # decrement_eventually() + update() must lower it to 9.

        class Connection(current_app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = Consumer(self.ready_queue, timer=self.timer)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)

        l.loop(*l.loop_args())
        l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        # ignore_errors() must swallow connection and channel errors but
        # let any other exception through.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.connection_errors = (AttributeError, KeyError, )
        l.channel_errors = (SyntaxError, )
        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
        ignore_errors(l, Mock(side_effect=KeyError('foo')))
        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
        with self.assertRaises(IndexError):
            ignore_errors(l, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        # apply_eta_task() must reserve the request, decrement the QoS value
        # and put the request on the ready queue.
        # NOTE(review): the task object is left in the global
        # state.reserved_requests after this test.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.qos = QoS(None, 10)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-format ETA must be scheduled on the timer
        # and raise the QoS prefetch value.
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        m = create_message(Mock(), task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 1)
        current_pcount = l.qos.value
        l.event_dispatcher = Mock()
        l.enabled = False
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        l.timer.stop()
        l.timer.join(1)

        # Each timer entry is a tuple whose third item holds the request as
        # its first argument.
        found = any(entry[2].args[0].name == foo_task.name
                    for entry in self.timer.queue)
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)

    def test_pidbox_callback(self):
        # The control box's on_message() must forward to
        # node.handle_message(); errors raised by the handler (KeyError,
        # ValueError) must not escape, and reset() must end up called.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        self.assertTrue(con.reset.called)

    def test_revoke(self):
        # A message whose id has been revoked must not reach the ready
        # queue.
        ready_queue = FastQueue()
        l = _MyKombuConsumer(ready_queue, timer=self.timer)
        l.steps.pop()
        backend = Mock()
        task_id = uuid()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        callback = self._get_on_message(l)
        callback(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for an unregistered task must return a falsy value and
        # be neither queued nor scheduled.
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        # When rejecting a message without a task name raises a connection
        # error, the callback must warn, log critically and queue nothing.
        l = Consumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.connection_errors = (socket.error, )
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())
        m.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receive_message_eta(self):
        # A task with an ETA one day ahead must be held on the timer rather
        # than the ready queue. Also starts the namespace with
        # BROKER_CONNECTION_RETRY disabled.
        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend, task=foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )

        l.namespace.start(l)
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.namespace.start(l)
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.namespace.restart(l)
        l.event_dispatcher = Mock()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        l.timer.stop()
        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()

    def test_reset_pidbox_node(self):
        # reset() must close the node's channel and tolerate the
        # socket.error (a connection error here) that close() raises.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error('foo')
        l.connection_errors = (socket.error, )
        con.reset()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a green pool the Control step must use gPidbox and spawn the
        # box loop on the pool.
        from celery.worker.pidbox import gPidbox
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
        con = find_step(l, consumer.Control)
        self.assertIsInstance(con.box, gPidbox)
        con.start(l)
        l.pool.spawn_n.assert_called_with(
            con.box.loop, l,
        )

    def test__green_pidbox_node(self):
        # The green pidbox loop must consume from node.listen(), keep going
        # after a socket.timeout, and close its connection on exit.
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out; the second detaches the consumer's
                # connection and signals the box loop to shut down.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l.connect = lambda: Connection(obj=l)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_connect_errback(self, sleep, connect):
        # connect() must retry after a connection error: the patched
        # _establish_connection raises StdChannelError on the first call
        # only. kombu.utils.sleep is patched, presumably so retries do not
        # actually wait.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        from kombu.transport.memory import Transport
        # Transport is a module-level class shared by every test, so restore
        # its connection_errors afterwards instead of leaking the override.
        prev_connection_errors = Transport.connection_errors
        Transport.connection_errors = (StdChannelError, )
        try:
            def effect():
                if connect.call_count > 1:
                    return
                raise StdChannelError()
            connect.side_effect = effect
            l.connect()
            connect.assert_called_with()
        finally:
            Transport.connection_errors = prev_connection_errors

    def test_stop_pidbox_node(self):
        # Control.stop() must return when the node-stopped event is already
        # set.
        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
        cont = find_step(l, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(l)

    def test_start__loop(self):
        # start() must keep calling loop() until it raises: the first run
        # stops on KeyError after two iterations (QoS updated on the way),
        # and a socket.error from loop() must also propagate.

        class _QoS(object):
            # update() copies `value` into `prev` so the test can tell that
            # an update took place.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.ready_queue, timer=self.timer,
                      init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(*args, **kwargs):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.loop = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertEqual(l.iterations, 2)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.ready_queue, timer=self.timer,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.loop = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(l.loop.call_count)

    def test_reset_connection_with_no_node(self):
        # Starting the namespace with the last step removed and no pool must
        # not raise.
        l = Consumer(self.ready_queue, timer=self.timer)
        l.steps.pop()
        self.assertIsNone(l.pool)
        l.namespace.start(l)

    def test_on_task_revoked(self):
        # on_task() must accept a revoked request without raising.
        l = Consumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = True
        l.on_task(task)

    def test_on_task_no_events(self):
        # on_task() must accept a non-revoked request without an ETA when
        # events and info logging are disabled, without raising.
        l = Consumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = False
        l.event_dispatcher = Mock()
        l.event_dispatcher.enabled = False
        task.eta = None
        l._does_info = False
        l.on_task(task)
# Example #5
# 0
class test_Consumer(Case):
    """Tests for the worker consumer (``MyKombuConsumer``/``BlockingConsumer``).

    Every test gets a fresh ready queue and ETA timer; the timer is stopped
    again in ``tearDown``.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.timer = Timer()

    def tearDown(self):
        self.timer.stop()

    def test_info(self):
        """``info`` reports prefetch count, and broker info once connected."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.qos = QoS(consumer.task_consumer, 10)
        info = consumer.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        consumer.connection = current_app.broker_connection()
        info = consumer.info
        self.assertTrue(info["broker"])

    def test_start_when_closed(self):
        """``start()`` must not fail when the consumer is already CLOSE."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer._state = CLOSE
        consumer.start()

    def test_connection(self):
        """Connection lifecycle: reset, stop consumers, stop, close."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)

        # Stopping consumers while explicitly keeping the connection.
        consumer._state = RUN
        consumer.event_dispatcher = None
        consumer.stop_consumers(close_connection=False)
        self.assertTrue(consumer.connection)

        # By default stopping consumers also drops the connection.
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)
        consumer.stop_consumers()

        consumer.stop()
        consumer.close_connection()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

    def test_close_connection(self):
        """``close_connection`` on a fresh consumer must not fail, and
        ``stop_consumers`` closes the event dispatcher and the heart."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer._state = RUN
        consumer.close_connection()

        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        eventer = consumer.event_dispatcher = Mock()
        eventer.enabled = True
        heart = consumer.heart = MockHeart()
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch("celery.worker.consumer.warn")
    def test_receive_message_unknown(self, warn):
        """A message without a known task is warned about."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        message = create_message(backend, unknown={"baz": "!!!"})
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        consumer.receive_message(message.decode(), message)
        self.assertTrue(warn.call_count)

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """An ETA that overflows ``to_timestamp`` still gets the message acked."""
        to_timestamp.side_effect = OverflowError()
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        # NOTE(review): ``("2, 2")`` is a plain string, not a tuple; this test
        # never executes the task, so the value is kept unchanged.
        message = create_message(Mock(), task=foo_task.name,
                                 args=("2, 2"),
                                 kwargs={},
                                 eta=datetime.now().isoformat())
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()
        consumer.update_strategies()

        consumer.receive_message(message.decode(), message)
        self.assertTrue(message.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    @patch("celery.worker.consumer.error")
    def test_receive_message_InvalidTaskError(self, error):
        """Non-mapping ``kwargs`` make the message be logged as invalid."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(Mock(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        consumer.update_strategies()
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        consumer.receive_message(message.decode(), message)
        self.assertIn("Received invalid task message", error.call_args[0][0])

    @patch("celery.worker.consumer.crit")
    def test_on_decode_error(self, crit):
        """Undecodable messages are acked and logged as critical."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        message = MockMessage()
        consumer.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def test_receieve_message(self):
        """A valid task message becomes a ``Request`` on the ready queue."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(Mock(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={})
        consumer.update_strategies()

        consumer.event_dispatcher = Mock()
        consumer.receive_message(message.decode(), message)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_connection_error(self):
        """``start`` survives a connection error; other errors propagate."""

        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises KeyError, every later call SyntaxError.
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        consumer = MockConsumer(self.ready_queue, timer=self.timer,
                                send_events=False, pool=BasePool())
        consumer.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            consumer.start()
        consumer.heart.stop()
        consumer.timer.stop()

    def test_start_channel_error(self):
        # Regression test for AMQPChannelExceptions that can occur within the
        # consumer. (i.e. 404 errors)

        class MockConsumer(BlockingConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises KeyError, every later call SyntaxError.
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        consumer = MockConsumer(self.ready_queue, timer=self.timer,
                                send_events=False, pool=BasePool())

        consumer.channel_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            consumer.start()
        consumer.heart.stop()
        consumer.timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        """A ``socket.timeout`` from ``drain_events`` does not propagate."""

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Clearing the consumer's connection presumably ends the
                # consume loop after this call.
                self.obj.connection = None
                raise socket.timeout(10)

        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.connection = Connection()
        consumer.task_consumer = Mock()
        consumer.connection.obj = consumer
        consumer.qos = QoS(consumer.task_consumer, 10)
        consumer.consume_messages()

    def test_consume_messages_when_socket_error(self):
        """``socket.error`` propagates while RUN, but not once CLOSE."""

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error("foo")

        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer._state = RUN
        conn = consumer.connection = Connection()
        consumer.connection.obj = consumer
        consumer.task_consumer = Mock()
        consumer.qos = QoS(consumer.task_consumer, 10)
        with self.assertRaises(socket.error):
            consumer.consume_messages()

        consumer._state = CLOSE
        consumer.connection = conn
        consumer.consume_messages()

    def test_consume_messages(self):
        """``consume_messages`` starts consuming and applies QoS changes."""

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.connection = Connection()
        consumer.connection.obj = consumer
        consumer.task_consumer = Mock()
        consumer.qos = QoS(consumer.task_consumer, 10)

        consumer.consume_messages()
        consumer.consume_messages()
        self.assertTrue(consumer.task_consumer.consume.call_count)
        consumer.task_consumer.qos.assert_called_with(prefetch_count=10)
        # A decremented prefetch count is pushed on the next call.
        consumer.qos.decrement()
        consumer.consume_messages()
        consumer.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        """AttributeError, connection and channel errors are suppressed;
        anything else propagates."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.connection_errors = (KeyError, )
        consumer.channel_errors = (SyntaxError, )
        consumer.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
        consumer.maybe_conn_error(Mock(side_effect=KeyError("foo")))
        consumer.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
        with self.assertRaises(IndexError):
            consumer.maybe_conn_error(Mock(side_effect=IndexError("foo")))

    def test_apply_eta_task(self):
        """An applied ETA task is reserved, queued and lowers the QoS value."""
        from celery.worker import state
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.qos = QoS(None, 10)

        task = object()
        qos_before = consumer.qos.value
        consumer.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(consumer.qos.value, qos_before - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        """A message with an ISO-8601 ``eta`` is scheduled on the timer."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        message = create_message(Mock(), task=foo_task.name,
                                 eta=datetime.now().isoformat(),
                                 args=[2, 4, 8], kwargs={})

        consumer.task_consumer = Mock()
        consumer.qos = QoS(consumer.task_consumer,
                           consumer.initial_prefetch_count)
        consumer.event_dispatcher = Mock()
        # NOTE(review): this sets an attribute on the consumer itself, not on
        # the event dispatcher; confirm which one was intended.
        consumer.enabled = False
        consumer.update_strategies()
        consumer.receive_message(message.decode(), message)
        consumer.timer.stop()

        # Timer queue items are indexed as (eta, priority, entry); the
        # entry's first argument is the scheduled request.
        scheduled = [queued[2] for queued in self.timer.queue]
        self.assertTrue(any(item.args[0].name == foo_task.name
                            for item in scheduled))
        self.assertTrue(consumer.task_consumer.qos.call_count)

    def test_on_control(self):
        """Control messages go to the pidbox node; a ValueError resets it."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.pidbox_node = Mock()
        consumer.reset_pidbox_node = Mock()

        consumer.on_control("foo", "bar")
        consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")

        consumer.pidbox_node = Mock()
        consumer.pidbox_node.handle_message.side_effect = KeyError("foo")
        consumer.on_control("foo", "bar")
        consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")

        consumer.pidbox_node = Mock()
        consumer.pidbox_node.handle_message.side_effect = ValueError("foo")
        consumer.on_control("foo", "bar")
        consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")
        consumer.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        """A message whose task id is revoked is not put on the queue."""
        ready_queue = FastQueue()
        consumer = MyKombuConsumer(ready_queue, timer=self.timer)
        backend = Mock()
        # Named ``task_id`` rather than ``id`` to avoid shadowing the builtin.
        task_id = uuid()
        message = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                                 kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        consumer.receive_message(message.decode(), message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """Messages for unregistered tasks are neither queued nor scheduled."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        message = create_message(backend, task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = Mock()
        self.assertFalse(consumer.receive_message(message.decode(), message))
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())

    @patch("celery.worker.consumer.warn")
    @patch("celery.worker.consumer.logger")
    def test_receieve_message_ack_raises(self, logger, warn):
        """A connection error while rejecting is logged, not propagated."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        backend = Mock()
        message = create_message(backend, args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = Mock()
        consumer.connection_errors = (socket.error, )
        message.reject = Mock()
        message.reject.side_effect = socket.error("foo")
        self.assertFalse(consumer.receive_message(message.decode(), message))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.timer.empty())
        message.reject.assert_called_with()
        self.assertTrue(logger.critical.call_count)

    def test_receieve_message_eta(self):
        """A future-ETA message is held on the timer, not the ready queue."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.event_dispatcher = Mock()
        consumer.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        message = create_message(backend, task=foo_task.name,
                                 args=[2, 4, 8], kwargs={},
                                 eta=(datetime.now() +
                                      timedelta(days=1)).isoformat())

        consumer.reset_connection()
        saved_retry = consumer.app.conf.BROKER_CONNECTION_RETRY
        consumer.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            consumer.reset_connection()
        finally:
            consumer.app.conf.BROKER_CONNECTION_RETRY = saved_retry
        consumer.stop_consumers()
        consumer.event_dispatcher = Mock()
        consumer.receive_message(message.decode(), message)
        consumer.timer.stop()
        in_hold = consumer.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        _eta, _priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()

    def test_reset_pidbox_node(self):
        """A connection error while closing the pidbox channel is tolerated."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.pidbox_node = Mock()
        chan = consumer.pidbox_node.channel = Mock()
        consumer.connection = Mock()
        chan.close.side_effect = socket.error("foo")
        consumer.connection_errors = (socket.error, )
        consumer.reset_pidbox_node()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool the pidbox loop is spawned on the pool."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.pool = Mock()
        consumer.pool.is_green = True
        consumer.reset_pidbox_node()
        consumer.pool.spawn_n.assert_called_with(consumer._green_pidbox_node)

    def test__green_pidbox_node(self):
        """The green pidbox loop listens on a new connection and closes it."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer.pidbox_node = Mock()

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        consumer.pidbox_node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def drain_events(self, **kwargs):
                # First drain times out; the next one signals shutdown.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                self.obj._pidbox_node_shutdown.set()

            def close(self):
                self.closed = True

        consumer.connection = Mock()
        consumer._open_connection = lambda: Connection(obj=consumer)
        consumer._green_pidbox_node()

        consumer.pidbox_node.listen.assert_called_with(
            callback=consumer.on_control)
        self.assertTrue(consumer.broadcast_consumer)
        consumer.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(consumer.connection)
        self.assertTrue(connections[0].closed)

    @patch("kombu.connection.BrokerConnection._establish_connection")
    @patch("kombu.utils.sleep")
    def test_open_connection_errback(self, sleep, connect):
        """``_open_connection`` recovers from an initial connection error."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        from kombu.transport.memory import Transport

        def effect():
            # Fail only on the first connection attempt.
            if connect.call_count > 1:
                return
            raise StdChannelError()

        connect.side_effect = effect
        # ``Transport`` is a shared class: restore its attribute afterwards
        # so the patched error tuple does not leak into other tests.
        saved_errors = Transport.connection_errors
        Transport.connection_errors = (StdChannelError, )
        try:
            consumer._open_connection()
        finally:
            Transport.connection_errors = saved_errors
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        """``stop_pidbox_node`` completes when the stopped event is set."""
        consumer = MyKombuConsumer(self.ready_queue, timer=self.timer)
        consumer._pidbox_node_stopped = Event()
        consumer._pidbox_node_shutdown = Event()
        consumer._pidbox_node_stopped.set()
        consumer.stop_pidbox_node()

    def test_start__consume_messages(self):
        """``start`` updates QoS, calls ``init_callback`` and lets errors out."""

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                # Fail every reconnect after the first consume call.
                if self.iterations >= 1:
                    raise KeyError("foo")

        init_callback = Mock()
        consumer = _Consumer(self.ready_queue, timer=self.timer,
                             init_callback=init_callback)
        consumer.task_consumer = Mock()
        consumer.broadcast_consumer = Mock()
        consumer.qos = _QoS()
        consumer.connection = BrokerConnection()
        consumer.iterations = 0

        def raises_KeyError(limit=None):
            consumer.iterations += 1
            if consumer.qos.prev != consumer.qos.value:
                consumer.qos.update()
            if consumer.iterations >= 2:
                raise KeyError("foo")

        consumer.consume_messages = raises_KeyError
        with self.assertRaises(KeyError):
            consumer.start()
        self.assertTrue(init_callback.call_count)
        self.assertEqual(consumer.iterations, 1)
        self.assertEqual(consumer.qos.prev, consumer.qos.value)

        init_callback.reset_mock()
        consumer = _Consumer(self.ready_queue, timer=self.timer,
                             send_events=False, init_callback=init_callback)
        consumer.qos = _QoS()
        consumer.task_consumer = Mock()
        consumer.broadcast_consumer = Mock()
        consumer.connection = BrokerConnection()
        consumer.consume_messages = Mock(side_effect=socket.error("foo"))
        with self.assertRaises(socket.error):
            consumer.start()
        self.assertTrue(init_callback.call_count)
        self.assertTrue(consumer.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """``reset_connection`` works for a consumer without a pool."""
        consumer = BlockingConsumer(self.ready_queue, timer=self.timer)
        self.assertIsNone(consumer.pool)
        consumer.reset_connection()

    def test_on_task_revoked(self):
        """``on_task`` accepts a task that reports itself as revoked."""
        consumer = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = True
        consumer.on_task(task)

    def test_on_task_no_events(self):
        """``on_task`` handles a non-revoked, ETA-less task with events off."""
        consumer = BlockingConsumer(self.ready_queue, timer=self.timer)
        task = Mock()
        task.revoked.return_value = False
        consumer.event_dispatcher = Mock()
        consumer.event_dispatcher.enabled = False
        task.eta = None
        consumer._does_info = False
        consumer.on_task(task)
Пример #6
0
class test_Consumer(Case):
    def setUp(self):
        """Create a fresh ready queue, ETA timer and default logger."""
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        default_logger = current_app.log.get_default_logger()
        default_logger.setLevel(0)
        self.logger = default_logger

    def tearDown(self):
        """Stop the ETA timer created in ``setUp``."""
        schedule = self.eta_schedule
        schedule.stop()

    def test_info(self):
        """``info`` reports prefetch count, and broker info once connected."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)

        before = consumer.info
        self.assertEqual(before["prefetch_count"], 10)
        self.assertFalse(before["broker"])

        consumer.connection = current_app.broker_connection()
        after = consumer.info
        self.assertTrue(after["broker"])

    def test_start_when_closed(self):
        """``start()`` must not fail when the consumer is already CLOSE."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer._state = CLOSE
        consumer.start()

    def test_connection(self):
        """Connection lifecycle: reset, stop consumers, stop, close."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)

        # Keep the connection open while stopping the consumers.
        consumer._state = RUN
        consumer.event_dispatcher = None
        consumer.stop_consumers(close_connection=False)
        self.assertTrue(consumer.connection)

        # The default stop drops both the connection and the task consumer.
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

        consumer.reset_connection()
        self.assertIsInstance(consumer.connection, BrokerConnection)
        consumer.stop_consumers()

        consumer.stop()
        consumer.close_connection()
        self.assertIsNone(consumer.connection)
        self.assertIsNone(consumer.task_consumer)

    def test_close_connection(self):
        """``close_connection`` on a fresh consumer must not fail, and
        ``stop_consumers`` closes the event dispatcher and the heart."""
        fresh = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                self.logger, send_events=False)
        fresh._state = RUN
        fresh.close_connection()

        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        eventer = Mock()
        consumer.event_dispatcher = eventer
        eventer.enabled = True
        heart = MockHeart()
        consumer.heart = heart
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """A message without a known task emits an 'unknown message' warning."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(Mock(), unknown={"baz": "!!!"})
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
            consumer.receive_message(message.decode(), message)

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """An ETA that overflows ``to_timestamp`` still gets the message acked."""
        to_timestamp.side_effect = OverflowError()
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        # NOTE(review): ``("2, 2")`` is a plain string, not a tuple; this test
        # never executes the task, so the value is kept unchanged.
        message = create_message(Mock(), task=foo_task.name,
                                 args=("2, 2"), kwargs={},
                                 eta=datetime.now().isoformat())
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()
        consumer.update_strategies()

        consumer.receive_message(message.decode(), message)
        self.assertTrue(message.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    def test_receive_message_InvalidTaskError(self):
        """Non-mapping ``kwargs`` make the message be logged as invalid."""
        mock_logger = Mock()
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   mock_logger, send_events=False)
        message = create_message(Mock(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        consumer.update_strategies()
        consumer.event_dispatcher = Mock()
        consumer.pidbox_node = MockNode()

        consumer.receive_message(message.decode(), message)
        logged = mock_logger.error.call_args[0][0]
        self.assertIn("Received invalid task message", logged)

    def test_on_decode_error(self):
        """Undecodable messages are acked and logged as critical."""
        mock_logger = Mock()
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   mock_logger, send_events=False)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        bad_message = MockMessage()
        consumer.on_decode_error(bad_message, KeyError("foo"))
        self.assertTrue(bad_message.ack.call_count)
        logged = mock_logger.critical.call_args[0][0]
        self.assertIn("Can't decode message body", logged)

    def test_receieve_message(self):
        """A valid task message becomes a ``Request`` on the ready queue."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(Mock(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={})
        consumer.update_strategies()

        consumer.event_dispatcher = Mock()
        consumer.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, Request)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """``start`` survives a connection error; other errors propagate."""

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises KeyError, every later call SyntaxError.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        consumer = MockConsumer(self.ready_queue, self.eta_schedule,
                                self.logger, send_events=False,
                                pool=BasePool())
        consumer.connection_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            consumer.start()
        consumer.heart.stop()
        consumer.priority_timer.stop()

    def test_start_channel_error(self):
        """Regression test for AMQP channel exceptions (e.g. 404 errors)
        raised inside the consumer: a channel error is survived, other
        errors propagate."""

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call raises KeyError, every later call SyntaxError.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        consumer = MockConsumer(self.ready_queue, self.eta_schedule,
                                self.logger, send_events=False,
                                pool=BasePool())

        consumer.channel_errors = (KeyError, )
        with self.assertRaises(SyntaxError):
            consumer.start()
        consumer.heart.stop()
        consumer.priority_timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        """A ``socket.timeout`` from ``drain_events`` does not propagate."""
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                # Clearing the consumer's connection presumably ends the
                # consume loop after this call.
                self.obj.connection = None
                raise socket.timeout(10)

        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        conn = Connection()
        consumer.connection = conn
        consumer.task_consumer = Mock()
        conn.obj = consumer
        consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)
        consumer.consume_messages()

    def test_consume_messages_when_socket_error(self):
        """``socket.error`` propagates while RUN, but not once CLOSE."""
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error("foo")

        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer._state = RUN
        conn = Connection()
        consumer.connection = conn
        conn.obj = consumer
        consumer.task_consumer = Mock()
        consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)
        with self.assertRaises(socket.error):
            consumer.consume_messages()

        consumer._state = CLOSE
        consumer.connection = conn
        consumer.consume_messages()

    def test_consume_messages(self):
        """``consume_messages`` starts consuming and applies QoS changes."""
        base = current_app.broker_connection().__class__

        class Connection(base):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        conn = Connection()
        consumer.connection = conn
        conn.obj = consumer
        task_consumer = consumer.task_consumer = Mock()
        consumer.qos = QoS(task_consumer, 10, consumer.logger)

        for _ in range(2):
            consumer.consume_messages()
        self.assertTrue(task_consumer.consume.call_count)
        task_consumer.qos.assert_called_with(prefetch_count=10)
        # A decremented prefetch count is pushed on the next call.
        consumer.qos.decrement()
        consumer.consume_messages()
        task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        """AttributeError, connection and channel errors are suppressed;
        anything else propagates."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.connection_errors = (KeyError, )
        consumer.channel_errors = (SyntaxError, )
        suppressed = (AttributeError("foo"), KeyError("foo"),
                      SyntaxError("foo"))
        for exc in suppressed:
            consumer.maybe_conn_error(Mock(side_effect=exc))
        with self.assertRaises(IndexError):
            consumer.maybe_conn_error(Mock(side_effect=IndexError("foo")))

    def test_apply_eta_task(self):
        """An applied ETA task is reserved, queued and lowers the QoS value."""
        from celery.worker import state
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.qos = QoS(None, 10, consumer.logger)

        sentinel_task = object()
        prefetch_before = consumer.qos.value
        consumer.apply_eta_task(sentinel_task)
        self.assertIn(sentinel_task, state.reserved_requests)
        self.assertEqual(consumer.qos.value, prefetch_before - 1)
        self.assertIs(self.ready_queue.get_nowait(), sentinel_task)

    def test_receieve_message_eta_isoformat(self):
        """A message with an ISO-8601 ``eta`` is scheduled on the ETA timer."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(Mock(),
                                 task=foo_task.name,
                                 eta=datetime.now().isoformat(),
                                 args=[2, 4, 8],
                                 kwargs={})

        consumer.task_consumer = Mock()
        consumer.qos = QoS(consumer.task_consumer,
                           consumer.initial_prefetch_count, consumer.logger)
        consumer.event_dispatcher = Mock()
        # NOTE(review): this sets an attribute on the consumer itself, not on
        # the event dispatcher; confirm which one was intended.
        consumer.enabled = False
        consumer.update_strategies()
        consumer.receive_message(message.decode(), message)
        consumer.eta_schedule.stop()

        # Timer queue items are indexed as (eta, priority, entry); the
        # entry's first argument is the scheduled request.
        scheduled = [queued[2] for queued in self.eta_schedule.queue]
        self.assertTrue(any(item.args[0].name == foo_task.name
                            for item in scheduled))
        self.assertTrue(consumer.task_consumer.qos.call_count)

    def test_on_control(self):
        """Control messages reach the pidbox node, including when the node
        raises KeyError or ValueError; the node is reset afterwards."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.reset_pidbox_node = Mock()

        # ``None`` leaves the default (no side effect) for the first round.
        for side_effect in (None, KeyError("foo"), ValueError("foo")):
            node = consumer.pidbox_node = Mock()
            node.handle_message.side_effect = side_effect
            consumer.on_control("foo", "bar")
            node.handle_message.assert_called_with("foo", "bar")
        consumer.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        """A message whose task id is revoked is not put on the queue."""
        ready_queue = FastQueue()
        consumer = MyKombuConsumer(ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        backend = Mock()
        # Named ``task_id`` rather than ``id`` to avoid shadowing the builtin.
        task_id = uuid()
        message = create_message(backend,
                                 task=foo_task.name,
                                 args=[2, 4, 8],
                                 kwargs={},
                                 id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        consumer.receive_message(message.decode(), message)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """Messages for unregistered tasks are neither queued nor scheduled."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(Mock(), task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = Mock()
        accepted = consumer.receive_message(message.decode(), message)
        self.assertFalse(accepted)
        with self.assertRaises(Empty):
            self.ready_queue.get_nowait()
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_ack_raises(self):
        """A connection error raised by ``ack()`` is logged, not propagated.

        The message is built without a task name, and an "unknown message"
        RuntimeWarning is expected.  ``ack()`` raises ``socket.error``, which
        is declared as a connection error, and the consumer must report it
        via ``logger.critical``.
        """
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(Mock(), args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = Mock()
        consumer.connection_errors = (socket.error, )
        consumer.logger = Mock()
        message.ack = Mock()
        message.ack.side_effect = socket.error("foo")

        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
            self.assertFalse(consumer.receive_message(message.decode(),
                                                      message))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())
        message.ack.assert_called_with()
        self.assertTrue(consumer.logger.critical.call_count)

    def test_receieve_message_eta(self):
        """A task due one day from now is put on the ETA timer, not the queue."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.event_dispatcher = Mock()
        consumer.event_dispatcher._outbound_buffer = deque()
        due = (datetime.now() + timedelta(days=1)).isoformat()
        message = create_message(Mock(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={}, eta=due)

        # Connect once normally, then again with connection retry disabled,
        # always restoring the original setting.
        consumer.reset_connection()
        saved_retry = consumer.app.conf.BROKER_CONNECTION_RETRY
        consumer.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            consumer.reset_connection()
        finally:
            consumer.app.conf.BROKER_CONNECTION_RETRY = saved_retry
        consumer.stop_consumers()

        consumer.event_dispatcher = Mock()
        consumer.receive_message(message.decode(), message)
        consumer.eta_schedule.stop()

        scheduled = self.eta_schedule.queue[0]
        self.assertEqual(len(scheduled), 3)
        _eta, _priority, timer_entry = scheduled
        request = timer_entry.args[0]
        self.assertIsInstance(request, Request)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_reset_pidbox_node(self):
        """reset_pidbox_node closes the old channel and tolerates a socket error.

        ``channel.close()`` raises ``socket.error``, which is declared as a
        connection error, so ``reset_pidbox_node()`` must not raise.
        """
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.pidbox_node = Mock()
        channel = consumer.pidbox_node.channel = Mock()
        consumer.connection = Mock()
        channel.close.side_effect = socket.error("foo")
        consumer.connection_errors = (socket.error, )

        consumer.reset_pidbox_node()
        channel.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool, reset_pidbox_node spawns ``_green_pidbox_node``."""
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        green_pool = consumer.pool = Mock()
        green_pool.is_green = True

        consumer.reset_pidbox_node()
        green_pool.spawn_n.assert_called_with(consumer._green_pidbox_node)

    def test__green_pidbox_node(self):
        """_green_pidbox_node listens for broadcasts on its own connection.

        The fake connection's ``drain_events`` clears ``consumer.connection``
        and sets ``_pidbox_node_shutdown``, presumably ending the loop.
        Afterwards, the broadcast consumer must have been consumed, the
        consumer's connection must be ``None``, and the connection opened by
        ``_open_connection`` must have been closed.
        """
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.pidbox_node = Mock()

        class _ListeningConsumer(Mock):
            # Context-manager mock: consume() on enter, cancel() on exit.

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        consumer.pidbox_node.listen = _ListeningConsumer()
        opened = []

        class _FakeConnection(object):
            # Minimal connection stand-in that records whether it was closed.

            def __init__(self, obj):
                opened.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def drain_events(self, **kwargs):
                # Drop the consumer's connection and request shutdown.
                self.obj.connection = None
                self.obj._pidbox_node_shutdown.set()

            def close(self):
                self.closed = True

        consumer.connection = Mock()
        consumer._open_connection = lambda: _FakeConnection(obj=consumer)
        consumer._green_pidbox_node()

        consumer.pidbox_node.listen.assert_called_with(
            callback=consumer.on_control)
        self.assertTrue(consumer.broadcast_consumer)
        consumer.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(consumer.connection)
        self.assertTrue(opened[0].closed)

    def test_start__consume_messages(self):
        """start() runs init_callback, syncs QoS and lets fatal errors escape.

        Scenario 1:
            The stubbed consume_messages updates QoS and raises KeyError
            from its second call on.  The overridden reset_connection raises
            KeyError once an iteration has happened.  start() must raise
            KeyError after exactly one iteration with QoS brought in sync.

        Scenario 2:
            consume_messages raises socket.error, which must propagate out
            of start().
        """
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

        def build(callback):
            # Consumer with mocked consumers and a default BrokerConnection().
            consumer = _Consumer(self.ready_queue,
                                 self.eta_schedule,
                                 self.logger,
                                 send_events=False,
                                 init_callback=callback)
            consumer.task_consumer = Mock()
            consumer.broadcast_consumer = Mock()
            consumer.qos = _QoS()
            consumer.connection = BrokerConnection()
            return consumer

        init_callback = Mock()

        first = build(init_callback)
        first.iterations = 0

        def consume_then_fail(limit=None):
            first.iterations += 1
            if first.qos.prev != first.qos.value:
                first.qos.update()
            if first.iterations >= 2:
                raise KeyError("foo")

        first.consume_messages = consume_then_fail
        self.assertRaises(KeyError, first.start)
        self.assertTrue(init_callback.call_count)
        self.assertEqual(first.iterations, 1)
        self.assertEqual(first.qos.prev, first.qos.value)

        init_callback.reset_mock()
        second = build(init_callback)
        second.consume_messages = Mock(side_effect=socket.error("foo"))
        self.assertRaises(socket.error, second.start)
        self.assertTrue(init_callback.call_count)
        self.assertTrue(second.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """A MainConsumer created without a pool can still reset its connection."""
        consumer = MainConsumer(self.ready_queue, self.eta_schedule,
                                self.logger)
        self.assertIsNone(consumer.pool)
        consumer.reset_connection()
Пример #7
0
class test_Consumer(unittest.TestCase):
    """Tests for the kombu-based worker consumer (MyKombuConsumer/MainConsumer)."""

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = current_app.log.get_default_logger()
        self.logger.setLevel(0)

    def tearDown(self):
        # Stop the ETA timer created in setUp.
        self.eta_schedule.stop()

    def test_info(self):
        """``info`` reports the prefetch count and whether a broker is set."""
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = current_app.broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_connection(self):
        """Connection and task consumer follow the reset/stop/close cycle."""
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        # close_connection=False must leave the connection in place.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close_connection=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        """close_connection() on a fresh consumer does not raise.

        stop_consumers() closes both the event dispatcher and the heart.
        """
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l._state = RUN
        l.close_connection()

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        eventer = l.event_dispatcher = Mock()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """A message carrying only an ``unknown`` field warns "unknown message"."""
        import warnings

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        backend = Mock()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        with catch_warnings(record=True) as log:
            # The default filters show a warning only once per code location.
            # If an earlier test already triggered it, ``log`` would stay
            # empty, so force every warning to be recorded.
            warnings.simplefilter("always")
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """An ETA whose timestamp conversion overflows is acked instead of raising."""
        to_timestamp.side_effect = OverflowError()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        # Pass real positional args: a parenthesised string is just a str.
        m = create_message(
            Mock(), task=foo_task.name, args=(2, 2), kwargs={}, eta=datetime.now().isoformat()
        )
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertTrue(m.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    def test_receive_message_InvalidTaskError(self):
        """A message whose kwargs is a string is logged as an invalid task message."""
        logger = Mock()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger, send_events=False)
        m = create_message(Mock(), task=foo_task.name, args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Received invalid task message", logger.error.call_args[0][0])

    def test_on_decode_error(self):
        """on_decode_error acks the message and logs a critical error."""
        logger = Mock()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger, send_events=False)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", logger.critical.call_args[0][0])

    def test_receieve_message(self):
        """A task without an ETA goes straight onto the ready queue."""
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        m = create_message(Mock(), task=foo_task.name, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        """start() retries after a connection error but not after other errors.

        KeyError is declared a connection error and is raised first; the
        SyntaxError raised on the retry must escape start().
        """
        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, pool=BasePool())
        l.connection_errors = (KeyError,)
        self.assertRaises(SyntaxError, l.start)
        l.heart.stop()
        l.priority_timer.stop()

    def test_consume_messages(self):
        """consume_messages starts the task consumer and applies QoS prefetch changes."""
        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Drop the consumer's connection so consume_messages returns.
                self.obj.connection = None

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        l.qos.decrement()
        l.consume_messages()
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        """maybe_conn_error swallows connection/channel errors and AttributeError.

        Any other exception (IndexError here) must be re-raised.
        """
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l.connection_errors = (KeyError,)
        l.channel_errors = (SyntaxError,)
        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error, Mock(side_effect=IndexError("foo")))

    def test_apply_eta_task(self):
        """apply_eta_task reserves the request, decrements QoS and queues it."""
        from celery.worker import state

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        """A task with an ETA of "now" is scheduled on the timer.

        The task consumer's QoS must also be updated (task_consumer.qos is
        called).
        """
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        m = create_message(Mock(), task=foo_task.name, eta=datetime.now().isoformat(), args=[2, 4, 8], kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Each timer queue entry is (eta, priority, entry); entry.args[0] is the request.
        found = any(entry[2].args[0].task_name == foo_task.name for entry in self.eta_schedule.queue)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.qos.call_count)
        l.eta_schedule.stop()

    def test_revoke(self):
        """A task whose id is in the revoked set never reaches the ready queue."""
        from celery.worker.state import revoked

        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger, send_events=False)
        backend = Mock()
        task_id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8], kwargs={}, id=task_id)

        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """An unregistered task name yields a falsy result with nothing queued."""
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        backend = Mock()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task due one day from now is held by the ETA timer, not queued."""
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False)
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(
            backend, task=foo_task.name, args=[2, 4, 8], kwargs={}, eta=(datetime.now() + timedelta(days=1)).isoformat()
        )

        # Reconnect once with connection retry disabled, restoring it afterwards.
        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """start() runs init_callback, syncs QoS and lets fatal errors escape.

        First, reset_connection raises KeyError once an iteration has run,
        so start() raises KeyError after one iteration.  Second,
        consume_messages raises socket.error, which must propagate.
        """
        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

        init_callback = Mock()
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(init_callback.call_count)
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger, send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = BrokerConnection()
        l.consume_messages = Mock(side_effect=socket.error("foo"))
        self.assertRaises(socket.error, l.start)
        self.assertTrue(init_callback.call_count)
        self.assertTrue(l.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        """A MainConsumer created without a pool can still reset its connection."""
        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
        self.assertIsNone(l.pool)
        l.reset_connection()
class test_CarrotListener(unittest.TestCase):
    """Tests for the carrot-based ``CarrotListener`` worker consumer."""

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Scheduler(self.ready_queue)
        self.logger = get_logger()
        self.logger.setLevel(0)

    def test_mainloop(self):
        """_mainloop yields drained events, and _detect_wait_method picks the loop.

        It should return _mainloop when the underlying connection supports
        drain_events, and otherwise fall back to task_consumer.iterconsume.
        """
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)

        class MockConnection(object):
            def drain_events(self):
                return "draining"

        l.connection = MockConnection()
        l.connection.connection = MockConnection()

        it = l._mainloop()
        # Compare the drained value itself; assertTrue would accept any
        # truthy value and treat "draining" as the failure message.
        self.assertEqual(it.next(), "draining")

        records = {}

        def create_recorder(key):
            def _recorder(*args, **kwargs):
                records[key] = True

            return _recorder

        l.task_consumer = PlaceHolder()
        l.task_consumer.iterconsume = create_recorder("consume_tasks")
        l.broadcast_consumer = PlaceHolder()
        l.broadcast_consumer.register_callback = create_recorder(
            "broadcast_callback")
        l.broadcast_consumer.iterconsume = create_recorder("consume_broadcast")
        l.task_consumer.add_consumer = create_recorder("consumer_add")

        records.clear()
        self.assertEqual(l._detect_wait_method(), l._mainloop)
        for record in ("broadcast_callback", "consume_broadcast",
                       "consume_tasks"):
            self.assertTrue(records.get(record))

        records.clear()
        l.connection.connection = PlaceHolder()
        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
        self.assertTrue(records.get("consumer_add"))

    def test_connection(self):
        """Connection and task consumer follow the reset/stop/close cycle."""
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_receive_message_control_command(self):
        """A control message is dispatched to ``control_dispatch``."""
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, control={"command": "shutdown"})
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()
        l.receive_message(m.decode(), m)
        self.assertIn("shutdown", l.control_dispatch.commands)

    def test_close_connection(self):
        """close_connection() on a fresh listener does not raise.

        stop_consumers() closes both the event dispatcher and the heart.
        """
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        l._state = RUN
        l.close_connection()

        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        """A message carrying only an ``unknown`` field warns "unknown message"."""
        import warnings

        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()

        def with_catch_warnings(log):
            # The default filters show a warning only once per code location.
            # Force recording so an earlier identical warning cannot leave
            # ``log`` empty.
            warnings.simplefilter("always")
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_InvalidTaskError(self):
        """A message whose kwargs is a string is logged as an ignored invalid task."""
        logger = MockLogger()
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=(1, 2),
                           kwargs="foobarbaz",
                           id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.control_dispatch = MockControlDispatch()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        """on_decode_error acks the message and logs a decoding error."""
        logger = MockLogger()
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           logger,
                           send_events=False)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Message decoding error", logger.logged[0])

    def test_receieve_message(self):
        """A task without an ETA goes straight onto the ready queue."""
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta_isoformat(self):
        """A task with an ETA of "now" is scheduled.

        The task consumer's qos() must also be called.
        """
        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.receive_message(m.decode(), m)

        # The third element of each scheduler entry is the task request.
        found = any(entry[2].task_name == foo_task.name
                    for entry in self.eta_schedule.queue)
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        """A revoke control command stops the matching task from being queued."""
        ready_queue = FastQueue()
        l = CarrotListener(ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()
        c = create_message(backend,
                           control={
                               "command": "revoke",
                               "task_id": task_id
                           })
        t = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(c.decode(), c)
        from celery.worker.state import revoked
        self.assertIn(task_id, revoked)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        """An unregistered task name yields a falsy result with nothing queued."""
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        """A task due one day from now is held by the scheduler, not queued."""
        l = CarrotListener(self.ready_queue,
                           self.eta_schedule,
                           self.logger,
                           send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        # Reconnect once with connection retry disabled, restoring it afterwards.
        l.reset_connection()
        p, conf.BROKER_CONNECTION_RETRY = conf.BROKER_CONNECTION_RETRY, False
        try:
            l.reset_connection()
        finally:
            conf.BROKER_CONNECTION_RETRY = p
        l.receive_message(m.decode(), m)

        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 4)
        eta, priority, task, on_accept = in_hold
        self.assertIsInstance(task, TaskRequest)
        self.assertTrue(callable(on_accept))
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        """start() runs init_callback, syncs QoS and lets fatal errors escape.

        First scenario:
            The wait method raises KeyError directly.

        Second scenario:
            The wait method raises socket.error.  Presumably that is treated
            as a connection error, so start() reconnects and the overridden
            reset_connection raises KeyError.

        Both scenarios must end with KeyError out of start().
        """
        class _QoS(object):
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Listener(CarrotListener):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

            def _detect_wait_method(self):
                return self.wait_method

        called_back = [False]

        def init_callback(listener):
            called_back[0] = True

        l = _Listener(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()

        def raises_KeyError(limit=None):
            yield True
            l.iterations = 1
            raise KeyError("foo")

        l.wait_method = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.next)

        # Reset the flag so the second scenario actually proves that
        # init_callback runs again for the new listener.
        called_back[0] = False
        l = _Listener(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()

        def raises_socket_error(limit=None):
            yield True
            l.iterations = 1
            raise socket.error("foo")

        l.wait_method = raises_socket_error
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
Пример #9
0
class test_Consumer(unittest.TestCase):
    """Tests for the kombu-based worker consumer (``MyKombuConsumer``).

    Every test builds a fresh consumer on top of the ready queue, ETA timer
    and logger created in ``setUp``.  Test methods are described with ``#``
    comments rather than docstrings so unittest keeps reporting them by
    method name.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = app_or_default().log.get_default_logger()
        self.logger.setLevel(0)  # 0 == logging.NOTSET

    def tearDown(self):
        # Stop the ETA timer created in setUp.
        self.eta_schedule.stop()

    def test_info(self):
        # ``info`` reports the QoS prefetch count; its "broker" entry is only
        # truthy once a connection has been assigned.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = app_or_default().broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_connection(self):
        # Walks the consumer through connect / stop / reconnect / close and
        # checks the connection and task consumer after each step.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        # close=False stops the consumers but keeps the connection.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close=False)
        self.assertTrue(l.connection)

        # The default also drops the connection and the task consumer.
        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() on a bare running consumer must not fail, and
        # stop_consumers() closes both the event dispatcher and the heart.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = RUN
        l.close_connection()

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A message carrying only an unrecognised payload must emit a warning
        # that mentions "unknown message".
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        # If the eta cannot be converted to a timestamp (OverflowError), the
        # message is still acknowledged instead of the error propagating.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            # Stand-in converter that always overflows.
            called[0] = True
            raise OverflowError()

        m = create_message(backend, task=foo_task.name,
                           args=(2, 2),  # positional args must be a sequence
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        # Patch the module-level converter and always restore it.
        from celery.worker import consumer
        prev, consumer.to_timestamp = consumer.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            consumer.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        # kwargs given as a string ("foobarbaz") must be rejected with an
        # "Invalid task ignored" log entry.
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        # An undecodable message is acked and a decode error is logged.
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)

        class MockMessage(object):
            # Minimal message stub recording whether ack() was called.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Can't decode message body", logger.logged[0])

    def test_receieve_message(self):
        # A plain task message becomes a TaskRequest on the ready queue and
        # nothing is scheduled on the ETA timer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        # A connection error (KeyError here) makes start() call
        # consume_messages() again; the SyntaxError from the second call is
        # not a connection error and escapes.

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
                         send_events=False)
        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.heart.stop()

    def test_consume_messages(self):
        # consume_messages() starts the task and broadcast consumers and
        # applies the QoS prefetch count, re-applying it after a decrement.
        app = app_or_default()

        class Connection(app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                # Presumably clearing the consumer's connection lets
                # consume_messages() return -- verify against MainConsumer.
                self.obj.connection = None

        class Consumer(object):
            consuming = False
            prefetch_count = 0

            def consume(self):
                self.consuming = True

            def qos(self, prefetch_size=0, prefetch_count=0,
                    apply_global=False):
                self.prefetch_count = prefetch_count

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Consumer()
        l.broadcast_consumer = Consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consuming)
        self.assertTrue(l.broadcast_consumer.consuming)
        self.assertEqual(l.task_consumer.prefetch_count, 10)

        l.qos.decrement()
        l.consume_messages()
        self.assertEqual(l.task_consumer.prefetch_count, 9)

    def test_maybe_conn_error(self):
        # maybe_conn_error() swallows the configured connection and channel
        # errors (and AttributeError) but lets IndexError propagate.

        def raises(error):

            def fun():
                raise error

            return fun

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(raises(AttributeError("foo")))
        l.maybe_conn_error(raises(KeyError("foo")))
        l.maybe_conn_error(raises(SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error,
                          raises(IndexError("foo")))

    def test_apply_eta_task(self):
        # apply_eta_task() reserves the request, decrements the QoS value by
        # one and puts the request on the ready queue.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A task whose eta is an ISO-8601 string is scheduled on the ETA
        # timer and qos() is called on the task consumer.

        class MockConsumer(object):
            # Records that qos() was called.
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Timer queue entries are (eta, priority, entry) tuples and the
        # entry's first argument is the TaskRequest.
        items = [entry[2] for entry in self.eta_schedule.queue]
        self.assertTrue(any(item.args[0].task_name == foo_task.name
                            for item in items))
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        # A message whose id is in the revoked set never reaches the ready
        # queue.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for an unregistered task name makes receive_message()
        # return a false value; nothing is queued or scheduled.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # A task with an eta one day ahead goes onto the ETA timer, not the
        # ready queue.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        # Reconnect twice, the second time with BROKER_CONNECTION_RETRY
        # disabled; the setting is always restored.
        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # start() calls init_callback, keeps calling consume_messages() until
        # an error escapes, and lets socket.error propagate.

        class _QoS(object):
            # update() copies ``value`` into ``prev``, so prev == value shows
            # that update() ran.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Succeeds until consume_messages() has run once.
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        # Scenario 1: the first consume_messages() returns normally; the
        # reset_connection() that follows raises KeyError.
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Scenario 2: socket.error from consume_messages() propagates.
        # Clear the flag first so the init_callback assertion below is about
        # this consumer, not a leftover from scenario 1.
        called_back[0] = False
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            l.iterations = 1
            raise socket.error("foo")

        l.consume_messages = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
Пример #10
0
class test_Consumer(unittest.TestCase):
    def setUp(self):
        # Fresh ready queue and ETA timer for every test.
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        # Shared default logger, reset to level 0 (logging.NOTSET).
        logger = current_app.log.get_default_logger()
        logger.setLevel(0)
        self.logger = logger

    def tearDown(self):
        # Stop the ETA timer created in setUp.
        timer = self.eta_schedule
        timer.stop()

    def test_info(self):
        # ``info`` exposes the prefetch count; its "broker" entry is only
        # truthy once a connection has been assigned.
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)

        before = consumer.info
        self.assertEqual(before["prefetch_count"], 10)
        self.assertFalse(before["broker"])

        consumer.connection = current_app.broker_connection()
        after = consumer.info
        self.assertTrue(after["broker"])

    def test_connection(self):
        # Walks the consumer through connect / stop / reconnect / close and
        # checks the connection and task consumer after each step.
        l = MyKombuConsumer(self.ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)

        # reset_connection() assigns a BrokerConnection to l.connection.
        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        # close=False stops the consumers but keeps the connection.
        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close=False)
        self.assertTrue(l.connection)

        # The default also drops the connection and the task consumer.
        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        # Reconnecting after a full stop works.
        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        # stop() followed by close_connection() leaves nothing attached.
        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() on a bare running consumer must not fail.
        bare = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                               self.logger, send_events=False)
        bare._state = RUN
        bare.close_connection()

        # stop_consumers() closes both the event dispatcher and the heart.
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        eventer = MockEventDispatcher()
        consumer.event_dispatcher = eventer
        heart = MockHeart()
        consumer.heart = heart
        consumer._state = RUN
        consumer.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A message carrying only an unrecognised payload must emit a warning
        # that mentions "unknown message".
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), unknown={"baz": "!!!"})
        consumer.event_dispatcher = MockEventDispatcher()
        consumer.pidbox_node = MockNode()

        def check(log):
            consumer.receive_message(message.decode(), message)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        execute_context(catch_warnings(record=True), check)

    def test_receive_message_eta_OverflowError(self):
        # If the eta cannot be converted to a timestamp (OverflowError), the
        # message is still acknowledged instead of the error propagating.
        l = MyKombuConsumer(self.ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            # Stand-in for timer2.to_timestamp that always overflows.
            called[0] = True
            raise OverflowError()

        m = create_message(backend,
                           task=foo_task.name,
                           args=(2, 2),  # positional args must be a sequence
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        # Patch the module-level converter and always restore it.
        prev, timer2.to_timestamp = timer2.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            timer2.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        # kwargs given as a string ("foobarbaz") must be rejected with an
        # "Invalid task ignored" log entry.
        logger = MockLogger()
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   logger, send_events=False)
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=(1, 2), kwargs="foobarbaz", id=1)
        consumer.event_dispatcher = MockEventDispatcher()
        consumer.pidbox_node = MockNode()

        consumer.receive_message(message.decode(), message)
        first_entry = logger.logged[0]
        self.assertIn("Invalid task ignored", first_entry)

    def test_on_decode_error(self):
        # An undecodable message is acked and a decode error is logged.
        logger = MockLogger()
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   logger, send_events=False)

        class MockMessage(object):
            # Minimal message stub recording whether ack() was called.
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        broken = MockMessage()
        consumer.on_decode_error(broken, KeyError("foo"))
        self.assertTrue(broken.acked)
        self.assertIn("Can't decode message body", logger.logged[0])

    def test_receieve_message(self):
        # A plain task message becomes a TaskRequest on the ready queue and
        # nothing is scheduled on the ETA timer.
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), task=foo_task.name,
                                 args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = MockEventDispatcher()
        consumer.receive_message(message.decode(), message)

        request = self.ready_queue.get_nowait()
        self.assertIsInstance(request, TaskRequest)
        self.assertEqual(request.task_name, foo_task.name)
        self.assertEqual(request.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        # A connection error (KeyError here) makes start() call
        # consume_messages() again; the SyntaxError from the second call is
        # not a connection error and escapes.
        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                # First call: simulated connection error; later: fatal error.
                if self.iterations:
                    raise SyntaxError("bar")
                self.iterations = 1
                raise KeyError("foo")

        consumer = MockConsumer(self.ready_queue, self.eta_schedule,
                                self.logger, send_events=False,
                                pool=BasePool())
        consumer.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, consumer.start)
        consumer.heart.stop()

    def test_consume_messages(self):
        # consume_messages() starts the task consumer and applies the QoS
        # prefetch count, re-applying it after the QoS value is decremented.
        connection_cls = current_app.broker_connection().__class__

        class Connection(connection_cls):
            obj = None

            def drain_events(self, **kwargs):
                # Presumably clearing the consumer's connection lets
                # consume_messages() return -- verify against MainConsumer.
                self.obj.connection = None

        class Consumer(object):
            consuming = False
            prefetch_count = 0

            def consume(self):
                self.consuming = True

            def qos(self, prefetch_size=0, prefetch_count=0,
                    apply_global=False):
                self.prefetch_count = prefetch_count

        worker = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                 self.logger, send_events=False)
        worker.connection = Connection()
        worker.connection.obj = worker
        worker.task_consumer = Consumer()
        worker.qos = QoS(worker.task_consumer, 10, worker.logger)

        for _ in range(2):
            worker.consume_messages()
        self.assertTrue(worker.task_consumer.consuming)
        self.assertEqual(worker.task_consumer.prefetch_count, 10)

        worker.qos.decrement()
        worker.consume_messages()
        self.assertEqual(worker.task_consumer.prefetch_count, 9)

    def test_maybe_conn_error(self):
        # maybe_conn_error() swallows the configured connection and channel
        # errors (and AttributeError) but lets IndexError propagate.
        def raising(exc):
            # Build a zero-argument callable that raises ``exc``.
            def _raise():
                raise exc
            return _raise

        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.connection_errors = (KeyError, )
        consumer.channel_errors = (SyntaxError, )
        swallowed = (AttributeError("foo"), KeyError("foo"),
                     SyntaxError("foo"))
        for exc in swallowed:
            consumer.maybe_conn_error(raising(exc))
        self.assertRaises(IndexError, consumer.maybe_conn_error,
                          raising(IndexError("foo")))

    def test_apply_eta_task(self):
        # apply_eta_task() reserves the request, decrements the QoS value by
        # one and puts the request on the ready queue.
        from celery.worker import state
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        consumer.qos = QoS(None, 10, consumer.logger)

        request = object()
        qos_before = consumer.qos.value
        consumer.apply_eta_task(request)
        self.assertIn(request, state.reserved_requests)
        self.assertEqual(consumer.qos.value, qos_before - 1)
        self.assertIs(self.ready_queue.get_nowait(), request)

    def test_receieve_message_eta_isoformat(self):
        # A task whose eta is an ISO-8601 string is scheduled on the ETA
        # timer, and qos() is called on the task consumer.
        class MockConsumer(object):
            # Records that qos() was called.
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyKombuConsumer(self.ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8],
                           kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Timer queue entries are (eta, priority, entry) tuples and the
        # entry's first argument is the TaskRequest.
        items = [entry[2] for entry in self.eta_schedule.queue]
        self.assertTrue(any(item.args[0].task_name == foo_task.name
                            for item in items))
        self.assertTrue(l.task_consumer.prefetch_count_incremented)

    def test_revoke(self):
        # A message whose id is in the revoked set never reaches the ready
        # queue.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)
        backend = MockBackend()
        task_id = gen_unique_id()  # avoid shadowing the builtin ``id``
        t = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           id=task_id)
        from celery.worker.state import revoked
        revoked.add(task_id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # A message for an unregistered task name makes receive_message()
        # return a false value; nothing is queued or scheduled.
        consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                   self.logger, send_events=False)
        message = create_message(MockBackend(), task="x.X.31x",
                                 args=[2, 4, 8], kwargs={})

        consumer.event_dispatcher = MockEventDispatcher()
        accepted = consumer.receive_message(message.decode(), message)
        self.assertFalse(accepted)
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # A task with an eta one day ahead goes onto the ETA timer, not the
        # ready queue.
        l = MyKombuConsumer(self.ready_queue,
                            self.eta_schedule,
                            self.logger,
                            send_events=False)
        l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend,
                           task=foo_task.name,
                           args=[2, 4, 8],
                           kwargs={},
                           eta=(datetime.now() +
                                timedelta(days=1)).isoformat())

        # Reconnect twice, the second time with BROKER_CONNECTION_RETRY
        # disabled; the original setting is always restored.
        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        # Timer queue entries are (eta, priority, entry) tuples; the entry's
        # first argument is the TaskRequest.
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # start() calls init_callback, keeps calling consume_messages() until
        # an error escapes, and lets socket.error propagate.
        class _QoS(object):
            # update() copies ``value`` into ``prev``, so prev == value shows
            # that update() ran.
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                # Succeeds until consume_messages() has run once.
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        # Scenario 1: the first consume_messages() returns normally; the
        # reset_connection() that follows raises KeyError.
        l = _Consumer(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        # Scenario 2: socket.error from consume_messages() propagates.
        # Clear the flag first so the init_callback assertion below is about
        # this consumer, not a leftover from scenario 1.
        called_back[0] = False
        l = _Consumer(self.ready_queue,
                      self.eta_schedule,
                      self.logger,
                      send_events=False,
                      init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            l.iterations = 1
            raise socket.error("foo")

        l.consume_messages = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)