    def test(self):
        obj = AMQPThreadedPusher.__new__(AMQPThreadedPusher)
        with rlocked_patch.object(AMQPThreadedPusher,
                                  'shutdown') as shutdown_mock:
            with obj:
                self.assertEqual(shutdown_mock.mock_calls, [])
            self.assertEqual(shutdown_mock.mock_calls, [call()])
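
The test above checks AMQPThreadedPusher's context-manager behavior: entering the `with` block must not trigger `shutdown()`, while leaving it must call `shutdown()` exactly once. A minimal sketch of that contract (illustrative only -- not the actual n6lib implementation):

class PusherContextManagerSketch(object):

    def shutdown(self):
        # (the real implementation closes the AMQP connection and
        # stops the publishing thread)
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.shutdown()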
Example #4
class AMQPHandler(logging.Handler):

    ERROR_LOGGER = 'AMQP_LOGGING_HANDLER_ERRORS'
    LOGRECORD_EXTRA_ATTRS = {
        'py_ver': '.'.join(map(str, sys.version_info)),
        'py_64bits': (sys.maxsize > 2**32),
        'py_ucs4': (sys.maxunicode > 0xffff),
        'py_platform': sys.platform,
        'hostname': HOSTNAME,
        'script_basename': SCRIPT_BASENAME,
    }
    LOGRECORD_KEY_MAX_LENGTH = 256

    DEFAULT_MSG_COUNT_WINDOW = 300
    DEFAULT_MSG_COUNT_MAX = 100

    DEFAULT_EXCHANGE_DECLARE_KWARGS = {'exchange_type': 'topic'}
    DEFAULT_RK_TEMPLATE = '{hostname}.{script_basename}.{levelname}.{loggername}'
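    # (e.g. 'myhost.myscript.ERROR.some.logger' -- the placeholders are
    # filled in by emit(); the sample values here are purely illustrative)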
    DEFAULT_PROP_KWARGS = dict(
        content_type='application/json',
        delivery_mode=1,
    )

    def __init__(self,
                 connection_params_dict,
                 exchange='logging',
                 exchange_declare_kwargs=None,
                 rk_template=None,
                 prop_kwargs=None,
                 other_pusher_kwargs=None,
                 error_logger_name=None,
                 msg_count_window=None,
                 msg_count_max=None,
                 **super_kwargs):
        if exchange_declare_kwargs is None:
            exchange_declare_kwargs = self.DEFAULT_EXCHANGE_DECLARE_KWARGS
        if rk_template is None:
            rk_template = self.DEFAULT_RK_TEMPLATE
        if prop_kwargs is None:
            prop_kwargs = self.DEFAULT_PROP_KWARGS
        if other_pusher_kwargs is None:
            other_pusher_kwargs = {}
        if error_logger_name is None:
            error_logger_name = self.ERROR_LOGGER
        if msg_count_window is None:
            msg_count_window = self.DEFAULT_MSG_COUNT_WINDOW
        if msg_count_max is None:
            msg_count_max = self.DEFAULT_MSG_COUNT_MAX

        super(AMQPHandler, self).__init__(**super_kwargs)

        self._rk_template = rk_template
        self._msg_count_window = msg_count_window
        self._msg_count_max = msg_count_max

        # error logging tools
        self._error_fifo = error_fifo = Queue.Queue()
        self._error_logger_name = error_logger_name
        self._error_logging_thread = threading.Thread(
            target=self._error_logging_loop,
            kwargs=dict(error_fifo=self._error_fifo,
                        error_logger=logging.getLogger(error_logger_name)))
        self._error_logging_thread.daemon = True
        self._closing = False

        def error_callback(exc):
            try:
                exc_info = sys.exc_info()
                assert exc_info[1] is exc
                error_msg = make_condensed_debug_msg(exc_info)
                error_fifo.put_nowait(error_msg)
            finally:
                # (to break any traceback-related reference cycle)
                exc_info = None

        # pusher instance
        self._pusher = AMQPThreadedPusher(
            connection_params_dict=connection_params_dict,
            exchange=dict(exchange_declare_kwargs, exchange=exchange),
            prop_kwargs=prop_kwargs,
            serialize=self._make_record_serializer(),
            error_callback=error_callback,
            **other_pusher_kwargs)

        # start error logging co-thread
        self._error_logging_thread.start()

    @classmethod
    def _error_logging_loop(cls, error_fifo, error_logger):
        try:
            while True:
                error_msg = error_fifo.get()
                error_logger.error('%s', error_msg)
        except:
            dump_condensed_debug_msg(
                'ERROR LOGGING CO-THREAD STOPS WITH EXCEPTION!')
            raise  # traceback should be printed to sys.stderr automatically

    def _make_record_serializer(self):
        defaultdict = collections.defaultdict
        formatter = logging.Formatter()
        # (note: on Python 2 the `reprlib` module is available as `repr`)
        json_encode = json.JSONEncoder(default=reprlib.repr).encode
        record_attrs_proto = self.LOGRECORD_EXTRA_ATTRS
        record_key_max_length = self.LOGRECORD_KEY_MAX_LENGTH
        match_useless_stack_item_regex = re.compile(
            r'  File "[ \S]*/python[0-9.]+/logging/__init__\.py\w?"').match

        msg_count_window = self._msg_count_window
        msg_count_max = self._msg_count_max
        # (using a 1-item list as a cell for a writable non-local variable)
        cur_window_cell = [None]
        loggername_to_window_and_msg_to_count = defaultdict(
            lambda: (cur_window_cell[0], defaultdict(lambda: -msg_count_max)))

        def _should_publish(record):
            # if, within the particular time window (window length is
            # defined as `msg_count_window`, in seconds), the number of
            # records from the particular logger that contain the same
            # `msg` (note: *not* necessarily the same `message`!)
            # exceeds the limit (defined as `msg_count_max`) --
            # further records containing that `msg` and originating from
            # that logger are skipped until *any* record from that
            # logger appears within *another* time window...
            loggername = record.name
            msg = record.msg
            cur_window = record.created // msg_count_window
            cur_window_cell[0] = cur_window
            window, msg_to_count = loggername_to_window_and_msg_to_count[
                loggername]
            if window != cur_window:
                # new time window for this logger
                # => attach (as the `msg_skipped_to_count` record
                #    attribute) the info about skipped messages (if
                #    any) and update/flush the state mappings
                msg_skipped_to_count = dict(
                    (m, c) for m, c in msg_to_count.iteritems() if c > 0)
                if msg_skipped_to_count:
                    record.msg_skipped_to_count = msg_skipped_to_count
                loggername_to_window_and_msg_to_count[
                    loggername] = cur_window, msg_to_count
                msg_to_count.clear()
            msg_to_count[msg] = count = msg_to_count[msg] + 1
            if count <= 0:
                if count == 0:
                    # this is the last record (from the particular
                    # logger + containing the particular `msg`) in
                    # the current time window that is *not* skipped
                    # => so it obtains the `msg_reached_count_max`
                    #    attribute (equal to `msg_count_max` value)
                    record.msg_reached_count_max = msg_count_max
                return True
            else:
                return False

        def serialize_record(record):
            if not _should_publish(record):
                return DoNotPublish
            record_attrs = record_attrs_proto.copy()
            record_attrs.update((
                limit_string(ascii_str(key), char_limit=record_key_max_length),
                value) for key, value in vars(record).items())
            record_attrs['message'] = record.getMessage()
            record_attrs['asctime'] = formatter.formatTime(record)
            exc_info = record_attrs.pop('exc_info', None)
            if exc_info:
                if not record_attrs['exc_text']:
                    record_attrs['exc_text'] = formatter.formatException(
                        exc_info)
                record_attrs['exc_type_repr'] = repr(exc_info[0])
                record_attrs['exc_ascii_str'] = ascii_str(exc_info[1])
            stack_items = record_attrs.pop('formatted_call_stack_items', None)
            if stack_items:
                del stack_items[-1]  # (this item is from this AMQPHandler.emit())
                while stack_items and match_useless_stack_item_regex(
                        stack_items[-1]):
                    del stack_items[-1]
                record_attrs['formatted_call_stack'] = ''.join(stack_items)
            return json_encode(record_attrs)

        return serialize_record

    def emit(self, record, _ERROR_LEVEL_NO=logging.ERROR):
        try:
            # ignore internal AMQP-handler-related error messages --
            # i.e., those logged with the handler's error logger
            # (to avoid infinite loop of message emissions...)
            if record.name == self._error_logger_name:
                return
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            logging.Handler.handleError(self, record)

        try:
            if record.levelno >= _ERROR_LEVEL_NO:
                record.formatted_call_stack_items = traceback.format_stack()
            routing_key = self._rk_template.format(
                hostname=HOSTNAME,
                script_basename=SCRIPT_BASENAME,
                levelname=record.levelname,
                loggername=record.name)
            try:
                self._pusher.push(record, routing_key)
            except ValueError:
                if not self._closing:
                    raise
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            self.handleError(record)

    def close(self):
        # typically, this method is called at interpreter exit
        # (by logging.shutdown() which is always registered with
        # atexit.register() machinery)
        try:
            try:
                super(AMQPHandler, self).close()
                self._closing = True
                self._pusher.shutdown()
            except:
                dump_condensed_debug_msg(
                    'EXCEPTION DURING EXECUTION OF close() '
                    'OF THE AMQP LOGGING HANDLER!')
                raise
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            exc = sys.exc_info()[1]
            self._error_fifo.put(exc)

    def handleError(self, record):
        try:
            exc = sys.exc_info()[1]
            self._error_fifo.put(exc)
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            logging.Handler.handleError(self, record)
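
A minimal usage sketch for the handler defined above (the broker address, exchange name and logger name below are hypothetical; the `connection_params_dict` items are passed through to pika.ConnectionParameters):

import logging

handler = AMQPHandler(
    connection_params_dict={'host': '127.0.0.1', 'port': 5672},
    exchange='logging')
logging.getLogger().addHandler(handler)
logging.getLogger('some.logger').error('oops!')  # record gets pushed to AMQP
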
Example #5
class TestAMQPThreadedPusher_internal_cooperation(unittest.TestCase):
    def setUp(self):
        # patching some global objects and preparing mockups of them
        self._stderr_patcher = rlocked_patch('sys.stderr')
        self.stderr_mock = self._stderr_patcher.start()
        self.addCleanup(self._stderr_patcher.stop)

        # patching some imported modules and preparing mockups of them
        self._time_patcher = rlocked_patch('n6lib.amqp_getters_pushers.time')
        self.time_mock = self._time_patcher.start()
        self.addCleanup(self._time_patcher.stop)

        self._traceback_patcher = rlocked_patch(
            'n6lib.amqp_getters_pushers.traceback')
        self.traceback_mock = self._traceback_patcher.start()
        self.addCleanup(self._traceback_patcher.stop)

        self._pika_patcher = rlocked_patch('n6lib.amqp_getters_pushers.pika')
        self.pika_mock = self._pika_patcher.start()
        self.addCleanup(self._pika_patcher.stop)

        # preparing sentinel exceptions
        class AMQPConnectionError_sentinel_exc(Exception):
            pass

        class ConnectionClosed_sentinel_exc(Exception):
            pass

        class generic_sentinel_exc(Exception):
            pass

        self.AMQPConnectionError_sentinel_exc = AMQPConnectionError_sentinel_exc
        self.ConnectionClosed_sentinel_exc = ConnectionClosed_sentinel_exc
        self.generic_sentinel_exc = generic_sentinel_exc

        # preparing mockups of different objects
        self.conn_mock = RLockedMagicMock()
        self.channel_mock = RLockedMagicMock()
        self.optional_setup_communication_mock = RLockedMagicMock()
        self.serialize = RLockedMagicMock()
        self.error_callback = RLockedMagicMock()

        # configuring the mockups
        self.pika_mock.exceptions.AMQPConnectionError = AMQPConnectionError_sentinel_exc
        self.pika_mock.exceptions.ConnectionClosed = ConnectionClosed_sentinel_exc
        self.pika_mock.ConnectionParameters.return_value = sen.conn_parameters
        self.pika_mock.BlockingConnection.side_effect = [
            AMQPConnectionError_sentinel_exc,
            self.conn_mock,
        ]
        self.pika_mock.BasicProperties.return_value = sen.props
        self.conn_mock.channel.return_value = self.channel_mock
        self.serialize.side_effect = (lambda data: data)

    #
    # Fixture helpers and reusable assertions

    def _reset_mocks(self):
        self.pika_mock.reset_mock()
        self.time_mock.reset_mock()
        self.traceback_mock.reset_mock()
        self.stderr_mock.reset_mock()
        self.conn_mock.reset_mock()
        self.channel_mock.reset_mock()
        self.optional_setup_communication_mock.reset_mock()
        if self.serialize is not None:
            self.serialize.reset_mock()
        if self.error_callback is not None:
            self.error_callback.reset_mock()

    def _make_obj(self, **kw):
        # create and initialize a usable AMQPThreadedPusher instance
        self.obj = AMQPThreadedPusher(
            connection_params_dict={'conn_param': sen.param_value},
            exchange={'exchange': sen.exchange},
            queues_to_declare=[
                sen.queue1,
                {
                    'blabla': sen.blabla
                },
                {
                    'blabla': sen.blabla,
                    'callback': sen.callback
                },
            ],
            serialize=self.serialize,
            prop_kwargs={'prop_kwarg': sen.prop_value},
            mandatory=sen.mandatory,
            output_fifo_max_size=3,
            error_callback=self.error_callback,
            **kw)

    def _side_effect_for_publish(self, exception_seq):
        orig_publish = AMQPThreadedPusher._publish
        exceptions = iter(exception_seq)

        def _side_effect(data, routing_key, custom_prop_kwargs):
            exc = next(exceptions)
            if exc is None:
                return orig_publish(self.obj, data, routing_key,
                                    custom_prop_kwargs)
            else:
                raise exc

        return _side_effect
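
    # (this helper lets a test inject a scripted sequence of exceptions
    # into consecutive _publish() calls; a None in the sequence means
    # "fall back to the original implementation for this call")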

    def _mock_setup_communication(self):
        self.obj._setup_communication = self.optional_setup_communication_mock

    def _assert_setup_done(self):
        self.assertEqual(
            self.pika_mock.mock_calls,
            [
                call.ConnectionParameters(
                    conn_param=sen.param_value,
                    client_properties={
                        'information': CONN_PARAM_CLIENT_PROP_INFORMATION
                    }),
                call.BlockingConnection(sen.conn_parameters),
                # repeated after pika.exceptions.AMQPConnectionError:
                call.ConnectionParameters(
                    conn_param=sen.param_value,
                    client_properties={
                        'information': CONN_PARAM_CLIENT_PROP_INFORMATION
                    }),
                call.BlockingConnection(sen.conn_parameters),
            ])
        self.assertEqual(
            self.time_mock.sleep.mock_calls,
            [
                # after pika.exceptions.AMQPConnectionError
                call(0.5),  # 0.5 == CONNECTION_RETRY_DELAY
            ])
        self.assertIs(self.obj._connection, self.conn_mock)
        self.assertIs(self.obj._channel, self.channel_mock)
        self.assertEqual(self.channel_mock.mock_calls, [
            call.exchange_declare(exchange=sen.exchange),
            call.queue_declare(queue=sen.queue1, callback=ANY),
            call.queue_declare(blabla=sen.blabla, callback=ANY),
            call.queue_declare(blabla=sen.blabla, callback=sen.callback),
        ])
        # some additional marginal asserts:
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)
        self.assertFalse(self.obj._connection_closed)

    def _assert_publishing_started(self):
        self.assertTrue(self.obj._publishing)
        self.assertTrue(self.obj._publishing_thread.is_alive())

    def _assert_shut_down(self):
        self.assertEqual(self.conn_mock.close.mock_calls, [call()])
        self.assertFalse(self.obj._publishing_thread.is_alive())
        self.assertFalse(self.obj._publishing)
        self.assertTrue(self.obj._connection_closed)

    def _assert_no_remaining_data(self):
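        # (note: a lone `None` left in the fifo is just the sentinel that
        # shutdown() puts there to wake up the publishing thread, not data)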
        self.assertIn(self.obj._output_fifo.queue, [
            collections.deque(),
            collections.deque([None]),
        ])

    @contextlib.contextmanager
    def _testing_normal_push(self, error_callback_call_count=0):
        self._make_obj()
        try:
            self._assert_setup_done()
            self._assert_publishing_started()

            yield self.obj

            while self.channel_mock.basic_publish.call_count < 2:
                time.sleep(0.01)
        finally:
            self.obj.shutdown()

        # properties of both published messages have been created properly
        # (using also custom prop kwargs if given)
        self.assertEqual(self.pika_mock.BasicProperties.mock_calls, [
            call(prop_kwarg=sen.prop_value),
            call(prop_kwarg=sen.prop_value, custom=sen.custom_value),
        ])

        # both messages have been published properly
        self.assertEqual(self.channel_mock.basic_publish.mock_calls, [
            call(
                exchange=sen.exchange,
                routing_key=sen.rk1,
                body=sen.data1,
                properties=sen.props,
                mandatory=sen.mandatory,
            ),
            call(
                exchange=sen.exchange,
                routing_key=sen.rk2,
                body=sen.data2,
                properties=sen.props,
                mandatory=sen.mandatory,
            ),
        ])

        self.assertEqual(self.error_callback.call_count,
                         error_callback_call_count)
        self._assert_shut_down()
        self._assert_no_remaining_data()

        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)

        # cannot push (as the pusher has been shut down)
        with self.assertRaises(ValueError):
            self.obj.push(sen.data3, sen.rk3)

    def _error_case_commons(self, subcall_mock, expected_subcall_count):
        self._make_obj()
        try:
            self._assert_setup_done()
            self._assert_publishing_started()

            self._reset_mocks()
            self._mock_setup_communication()

            self.obj.push(sen.data, sen.rk)

            # we must delay shutdown() to let the pub. thread operate
            while subcall_mock.call_count < expected_subcall_count:
                time.sleep(0.01)
        finally:
            self.obj.shutdown()

        self.assertEqual(self.serialize.mock_calls, [call(sen.data)])

        self._assert_shut_down()
        self._assert_no_remaining_data()

    #
    # Actual tests

    def test_normal_operation_without_serialize(self):
        self.serialize = None
        with self._testing_normal_push() as obj:
            obj.push(sen.data1, sen.rk1)
            obj.push(sen.data2, sen.rk2, {'custom': sen.custom_value})

    def test_normal_operation_with_serialize(self):
        self.serialize.side_effect = [
            sen.data1,
            DoNotPublish,
            sen.data2,
        ]
        with self._testing_normal_push() as obj:
            # published normally
            obj.push(sen.raw1, sen.rk1)

            # serialize() returns DoNotPublish for this one (see above)
            # so the data will not be published
            obj.push(sen.no_pub_data, sen.no_pub_rk)

            # published normally
            obj.push(sen.raw2, sen.rk2, {'custom': sen.custom_value})

        self.assertEqual(self.serialize.mock_calls, [
            call(sen.raw1),
            call(sen.no_pub_data),
            call(sen.raw2),
        ])
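
    # (note: `DoNotPublish` is a sentinel object -- presumably defined
    # alongside AMQPThreadedPusher in n6lib.amqp_getters_pushers -- which
    # tells the pusher to silently drop the message instead of publishing it)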

    def test_publishing_flag_is_False(self):
        # do not try to reconnect on pika.exceptions.ConnectionClosed
        # but continue publishing until the output fifo is empty
        def basic_publish_side_effect(*args, **kwargs):
            time.sleep(0.02)

        self.channel_mock.basic_publish.side_effect = basic_publish_side_effect
        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                side_effect=self._side_effect_for_publish(exception_seq=[
                    None,
                    self.ConnectionClosed_sentinel_exc,
                    None,
                ]),
        ) as _publish_mock:
            with self._testing_normal_push(error_callback_call_count=1) as obj:
                obj._publishing = False
                obj._output_fifo.put_nowait((sen.data1, sen.rk1, None))
                obj._output_fifo.put_nowait((sen.data_err, sen.rk_err, None))
                obj._output_fifo.put_nowait((sen.data2, sen.rk2, {
                    'custom': sen.custom_value
                }))

        self.assertEqual(_publish_mock.mock_calls, [
            call(sen.data1, sen.rk1, None),
            call(sen.data_err, sen.rk_err, None),
            call(sen.data2, sen.rk2, {'custom': sen.custom_value}),
        ])

        # no reconnections
        self.assertFalse(self.optional_setup_communication_mock.mock_calls)

        # one error callback call
        self.assertEqual(self.error_callback.mock_calls, [call(ANY)])

    def test_permanent_AMQPConnectionError(self):
        self.pika_mock.BlockingConnection.side_effect = self.AMQPConnectionError_sentinel_exc

        with self.assertRaises(self.AMQPConnectionError_sentinel_exc):
            self._make_obj()

        self.assertEqual(
            self.pika_mock.mock_calls,
            # (the call pair is repeated 10 times, because CONNECTION_ATTEMPTS == 10)
            10 * [
                call.ConnectionParameters(
                    conn_param=sen.param_value,
                    client_properties={
                        'information': CONN_PARAM_CLIENT_PROP_INFORMATION
                    }),
                call.BlockingConnection(sen.conn_parameters),
            ])
        self.assertEqual(
            self.time_mock.sleep.mock_calls,
            # (call(0.5) because CONNECTION_RETRY_DELAY == 0.5;
            # 9 calls because CONNECTION_ATTEMPTS == 10
            # and there is no delay after the last attempt)
            9 * [call(0.5)],
        )
        self.assertEqual(self.channel_mock.basic_publish.call_count, 0)
        self.assertEqual(self.error_callback.call_count, 0)
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)

    def test_publishing_with_one_ConnectionClosed(self):
        exceptions_from_publish = [
            self.ConnectionClosed_sentinel_exc,
            None,
        ]
        expected_publish_call_count = 2

        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                side_effect=self._side_effect_for_publish(
                    exceptions_from_publish)) as _publish_mock:
            self._error_case_commons(_publish_mock,
                                     expected_publish_call_count)

        self.assertEqual(self.optional_setup_communication_mock.mock_calls,
                         [call()])
        self.assertEqual(_publish_mock.mock_calls, [
            call(sen.data, sen.rk, None),
            call(sen.data, sen.rk, None),
        ])
        self.assertFalse(self.error_callback.mock_calls)
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)

        # properties of the published message have been created properly...
        self.assertEqual(self.pika_mock.BasicProperties.mock_calls, [
            call(prop_kwarg=sen.prop_value),
        ])
        # ...and the message has been published properly
        self.assertEqual(self.channel_mock.basic_publish.mock_calls, [
            call(
                exchange=sen.exchange,
                routing_key=sen.rk,
                body=sen.data,
                properties=sen.props,
                mandatory=sen.mandatory,
            ),
        ])

    def test_publishing_with_exceptions_and_error_callback(self):
        exceptions_from_publish = [
            self.ConnectionClosed_sentinel_exc,
            TypeError,
        ]
        expected_publish_call_count = 2

        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                side_effect=self._side_effect_for_publish(
                    exceptions_from_publish)) as _publish_mock:
            self._error_case_commons(_publish_mock,
                                     expected_publish_call_count)

        self.assertEqual(_publish_mock.mock_calls, [
            call(sen.data, sen.rk, None),
            call(sen.data, sen.rk, None),
        ])
        self.assertEqual(self.optional_setup_communication_mock.mock_calls,
                         [call()])
        self.assertEqual(self.error_callback.mock_calls, [call(ANY)])
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)

        # the message has not been published
        self.assertFalse(self.pika_mock.BasicProperties.mock_calls)
        self.assertFalse(self.channel_mock.basic_publish.mock_calls)

    def test_publishing_with_exceptions_and_no_error_callback(self):
        self.error_callback = None

        exceptions_from_publish = [
            self.ConnectionClosed_sentinel_exc,
            TypeError,
        ]
        expected_publish_call_count = 2

        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                side_effect=self._side_effect_for_publish(
                    exceptions_from_publish)) as _publish_mock:
            self._error_case_commons(_publish_mock,
                                     expected_publish_call_count)

        assert self.error_callback is None, "bug in test case"

        self.assertIsNone(self.obj._error_callback)
        self.assertEqual(_publish_mock.mock_calls, [
            call(sen.data, sen.rk, None),
            call(sen.data, sen.rk, None),
        ])
        self.assertEqual(self.optional_setup_communication_mock.mock_calls,
                         [call()])
        self.assertEqual(self.traceback_mock.print_exc.mock_calls, [call()])

        # the message has not been published
        self.assertFalse(self.pika_mock.BasicProperties.mock_calls)
        self.assertFalse(self.channel_mock.basic_publish.mock_calls)

    def test_publishing_with_error_callback_raising_exception(self):
        self.error_callback.side_effect = TypeError

        exceptions_from_publish = [
            self.ConnectionClosed_sentinel_exc,
            TypeError,
        ]
        expected_publish_call_count = 2

        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                side_effect=self._side_effect_for_publish(
                    exceptions_from_publish)) as _publish_mock:
            self._error_case_commons(_publish_mock,
                                     expected_publish_call_count)

        self.assertEqual(_publish_mock.mock_calls, [
            call(sen.data, sen.rk, None),
            call(sen.data, sen.rk, None),
        ])
        self.assertEqual(self.optional_setup_communication_mock.mock_calls,
                         [call()])
        self.assertEqual(self.error_callback.mock_calls, [call(ANY)])
        self.assertEqual(self.traceback_mock.print_exc.mock_calls, [call()])
        # (`print(..., file=sys.stderr)` is used...)
        self.assertTrue(self.stderr_mock.mock_calls)

        # the message has not been published
        self.assertFalse(self.pika_mock.BasicProperties.mock_calls)
        self.assertFalse(self.channel_mock.basic_publish.mock_calls)

    def test_serialization_error(self):
        self.serialize.side_effect = TypeError
        expected_serialize_call_count = 1
        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish'
        ) as _publish_mock:
            self._error_case_commons(self.serialize,
                                     expected_serialize_call_count)

        self.assertEqual(self.serialize.mock_calls, [call(sen.data)])
        self.assertEqual(self.error_callback.mock_calls, [call(ANY)])
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)
        self.assertFalse(self.stderr_mock.mock_calls)
        self.assertFalse(self.optional_setup_communication_mock.mock_calls)

        # the message has not been published
        self.assertFalse(_publish_mock.mock_calls)
        self.assertFalse(self.pika_mock.BasicProperties.mock_calls)
        self.assertFalse(self.channel_mock.basic_publish.mock_calls)

    def test_publishing_with_fatal_error(self):
        with rlocked_patch(
                'n6lib.amqp_getters_pushers.AMQPThreadedPusher._publish',
                # not an Exception subclass:
                side_effect=BaseException,
        ) as _publish_mock:
            self._make_obj()
            try:
                self.obj.push(sen.data, sen.rk)

                # we must wait to let the pub. thread operate and crash
                while _publish_mock.call_count < 1:
                    time.sleep(0.01)
                self.obj._publishing_thread.join(15.0)

                self.assertFalse(self.obj._publishing_thread.is_alive())
                self.assertFalse(self.obj._publishing)

                self.assertTrue(self.stderr_mock.mock_calls)
            finally:
                self.obj._publishing_thread.join(15.0)
                if self.obj._publishing_thread.is_alive():
                    raise RuntimeError('unexpected problem: the publishing '
                                       'thread did not terminate :-/')

        self.assertFalse(self.error_callback.mock_calls)
        self.assertFalse(self.traceback_mock.print_exc.mock_calls)

    def test_publishing_with_fatal_error_and_remaining_data_in_fifo(self):
        def serialize_side_effect(data):
            self.obj._output_fifo.put(sen.item)
            return data

        self.serialize.side_effect = serialize_side_effect

        self.test_publishing_with_fatal_error()

        assert (hasattr(self.obj._output_fifo, 'queue')
                and isinstance(self.obj._output_fifo.queue, collections.deque)
                ), "test case's assumption is invalid"
        underlying_deque = self.obj._output_fifo.queue
        self.assertEqual(underlying_deque, collections.deque([sen.item]))
        with self.assertRaisesRegex(ValueError, 'pending messages'):
            self.obj.shutdown()
        self.assertEqual(underlying_deque, collections.deque([sen.item]))

    def test_shutting_down_with_timeouted_join_to_publishing_thread(self):
        self._make_obj(publishing_thread_join_timeout=0.2)
        try:
            self._assert_setup_done()
            self._assert_publishing_started()

            output_fifo_put_nowait_orig = self.obj._output_fifo.put_nowait
            output_fifo_put_nowait_mock = Mock()

            # we must wait to let the pub. thread set the heartbeat flag
            # to True (and then that thread will hang on the fifo)
            while not self.obj._publishing_thread_heartbeat_flag:
                time.sleep(0.01)
        except:
            # (to make the pub. thread terminate on any error)
            self.obj.shutdown()
            raise

        try:
            # monkey-patching output_fifo.put_nowait() so that
            # shutdown() will *not* wake up the pub. thread
            self.obj._output_fifo.put_nowait = output_fifo_put_nowait_mock

            with self.assertRaisesRegex(
                    RuntimeError, 'pushing thread seems to be still alive'):
                self.obj.shutdown()

            # shutdown() returned because the join timeout expired and
            # heartbeat flag was not re-set to True by the pub. thread
            self.assertFalse(self.obj._publishing_thread_heartbeat_flag)

            # the pusher is shut down...
            self.assertEqual(self.conn_mock.close.mock_calls, [call()])
            self.assertFalse(self.obj._publishing)
            self.assertTrue(self.obj._connection_closed)

            # ...but the pub. thread is still alive
            self.assertTrue(self.obj._publishing_thread.is_alive())
        finally:
            # (to always make the pub. thread terminate finally)
            self.obj._publishing = False
            output_fifo_put_nowait_orig(None)

            # now the pub. thread should terminate shortly or be already terminated
            self.obj._publishing_thread.join(15.0)
            if self.obj._publishing_thread.is_alive():
                raise RuntimeError('unexpected problem: the publishing '
                                   'thread did not terminate :-/')

        self.assertEqual(output_fifo_put_nowait_mock.mock_calls, [call(None)])
Example #6
class AMQPHandler(logging.Handler):

    ERROR_LOGGER = 'AMQP_LOGGING_HANDLER_ERRORS'
    LOGRECORD_EXTRA_ATTRS = {
        'py_ver': '.'.join(map(str, sys.version_info)),
        'py_64bits': (sys.maxsize > 2**32),
        'py_platform': sys.platform,
        'hostname': HOSTNAME,
        'script_basename': SCRIPT_BASENAME,
    }
    LOGRECORD_KEY_MAX_LENGTH = 256

    DEFAULT_MSG_COUNT_WINDOW = 300
    DEFAULT_MSG_COUNT_MAX = 100

    DEFAULT_EXCHANGE_DECLARE_KWARGS = {'exchange_type': 'topic'}
    DEFAULT_RK_TEMPLATE = '{hostname}.{script_basename}.{levelname}.{loggername}'
    DEFAULT_PROP_KWARGS = dict(
        content_type='application/json',
        delivery_mode=1,
    )

    def __init__(self,
                 connection_params_dict,
                 exchange='logging',
                 exchange_declare_kwargs=None,
                 rk_template=None,
                 prop_kwargs=None,
                 other_pusher_kwargs=None,
                 error_logger_name=None,
                 msg_count_window=None,
                 msg_count_max=None,
                 **super_kwargs):
        if exchange_declare_kwargs is None:
            exchange_declare_kwargs = self.DEFAULT_EXCHANGE_DECLARE_KWARGS
        if rk_template is None:
            rk_template = self.DEFAULT_RK_TEMPLATE
        if prop_kwargs is None:
            prop_kwargs = self.DEFAULT_PROP_KWARGS
        if other_pusher_kwargs is None:
            other_pusher_kwargs = {}
        if error_logger_name is None:
            error_logger_name = self.ERROR_LOGGER
        if msg_count_window is None:
            msg_count_window = self.DEFAULT_MSG_COUNT_WINDOW
        if msg_count_max is None:
            msg_count_max = self.DEFAULT_MSG_COUNT_MAX

        super(AMQPHandler, self).__init__(**super_kwargs)

        self._rk_template = rk_template
        self._msg_count_window = msg_count_window
        self._msg_count_max = msg_count_max

        # error logging tools
        self._error_fifo = error_fifo = queue.Queue()
        self._error_logger_name = error_logger_name
        self._error_logging_thread = threading.Thread(
            target=self._error_logging_loop,
            kwargs=dict(error_fifo=self._error_fifo,
                        error_logger=logging.getLogger(error_logger_name)))
        self._error_logging_thread.daemon = True
        self._closing = False

        def error_callback(exc):
            try:
                exc_info = sys.exc_info()
                assert exc_info[1] is exc
                error_msg = make_condensed_debug_msg(exc_info)
                error_fifo.put_nowait(error_msg)
            finally:
                # (to break any traceback-related reference cycle)
                exc_info = exc = None  # noqa

        # pusher instance
        self._pusher = AMQPThreadedPusher(
            connection_params_dict=connection_params_dict,
            exchange=dict(exchange_declare_kwargs, exchange=exchange),
            prop_kwargs=prop_kwargs,
            serialize=self._make_record_serializer(),
            error_callback=error_callback,
            **other_pusher_kwargs)

        # start error logging co-thread
        self._error_logging_thread.start()

    @classmethod
    def _error_logging_loop(cls, error_fifo, error_logger):
        try:
            while True:
                error_msg = error_fifo.get()
                error_logger.error('%s', error_msg)
        except:
            dump_condensed_debug_msg(
                'ERROR LOGGING CO-THREAD STOPS WITH EXCEPTION!')
            raise  # traceback should be printed to sys.stderr automatically

    def _make_record_serializer(self):
        defaultdict = collections.defaultdict
        formatter = logging.Formatter()
        json_encode = json.JSONEncoder(default=reprlib.repr).encode
        record_attrs_proto = self.LOGRECORD_EXTRA_ATTRS
        record_key_max_length = self.LOGRECORD_KEY_MAX_LENGTH
        match_useless_stack_item_regex = re.compile(
            r'  File "[ \S]*/python[0-9.]+/logging/__init__\.py\w?"',
            re.ASCII,
        ).match
        # (see: https://github.com/python/cpython/blob/4f161e65a011f287227c944fad9987446644041f/Lib/logging/__init__.py#L1540)
        stack_info_preamble = 'Stack (most recent call last):\n'

        msg_count_window = self._msg_count_window
        msg_count_max = self._msg_count_max
        cur_window = None
        loggername_to_window_and_msg_to_count = defaultdict(
            lambda: (cur_window, defaultdict(lambda: -msg_count_max)))

        def _should_publish(record):
            # if, within the particular time window (window length is
            # defined as `msg_count_window`, in seconds), the number of
            # records from the particular logger that contain the same
            # `msg` (note: *not* necessarily the same `message`!)
            # exceeds the limit (defined as `msg_count_max`) --
            # further records containing that `msg` and originating from
            # that logger are skipped until *any* record from that
            # logger appears within *another* time window...
            nonlocal cur_window
            loggername = record.name
            msg = record.msg
            cur_window = record.created // msg_count_window
            window, msg_to_count = loggername_to_window_and_msg_to_count[
                loggername]
            if window != cur_window:
                # new time window for this logger
                # => attach (as the `msg_skipped_to_count` record
                #    attribute) the info about skipped messages (if
                #    any) and update/flush the state mappings
                msg_skipped_to_count = dict(
                    (m, c) for m, c in msg_to_count.items() if c > 0)
                if msg_skipped_to_count:
                    record.msg_skipped_to_count = msg_skipped_to_count
                loggername_to_window_and_msg_to_count[
                    loggername] = cur_window, msg_to_count
                msg_to_count.clear()
            msg_to_count[msg] = count = msg_to_count[msg] + 1
            if count <= 0:
                if count == 0:
                    # this is the last record (from the particular
                    # logger + containing the particular `msg`) in
                    # the current time window that is *not* skipped
                    # => so it obtains the `msg_reached_count_max`
                    #    attribute (equal to `msg_count_max` value)
                    record.msg_reached_count_max = msg_count_max
                return True
            else:
                return False

        def _try_to_extract_formatted_call_stack(record_attrs):
            # Provided by standard `logging` stuff if log method was called with `stack_info=True`:
            stack_info = record_attrs.pop('stack_info', None)
            # Provided by our `AMQPHandler.emit()` if `levelno` was at least `logging.ERROR` and
            # `stack_info` was not present:
            stack_items = record_attrs.pop('formatted_call_stack_items', None)
            if stack_items:
                del stack_items[-1]  # (this item is from our `AMQPHandler.emit()`)
                while stack_items and match_useless_stack_item_regex(
                        stack_items[-1]):
                    del stack_items[-1]
                joined = ''.join(stack_items)
                # (mimicking how `stack_info` is formatted by `logging`)
                if joined.endswith('\n'):
                    joined = joined[:-1]
                formatted_call_stack = stack_info_preamble + joined
            elif stack_info:
                # (`stack_info` should already start with `stack_info_preamble`)
                formatted_call_stack = stack_info
            else:
                formatted_call_stack = None
            return formatted_call_stack

        def serialize_record(record):
            if not _should_publish(record):
                return DoNotPublish
            record_attrs = record_attrs_proto.copy()
            record_attrs.update(
                (limit_str(ascii_str(key), char_limit=record_key_max_length),
                 value) for key, value in vars(record).items())
            record_attrs['message'] = record.getMessage()
            record_attrs['asctime'] = formatter.formatTime(record)
            exc_info = record_attrs.pop('exc_info', None)
            if exc_info:
                if not record_attrs['exc_text']:
                    record_attrs['exc_text'] = formatter.formatException(
                        exc_info)
                record_attrs['exc_type_repr'] = ascii(exc_info[0])
                record_attrs['exc_ascii_str'] = ascii_str(exc_info[1])
            formatted_call_stack = _try_to_extract_formatted_call_stack(
                record_attrs)
            if formatted_call_stack:
                record_attrs['formatted_call_stack'] = formatted_call_stack
            return json_encode(record_attrs)

        return serialize_record

    def emit(self, record, _ERROR_LEVEL_NO=logging.ERROR):
        try:
            # ignore internal AMQP-handler-related error messages --
            # i.e., those logged with the handler's error logger
            # (to avoid infinite loop of message emissions...)
            if record.name == self._error_logger_name:
                return
            # (exception here ^ is hardly probable, but you never know...)
        except RecursionError:  # see: https://bugs.python.org/issue36272 XXX: is it really needed?
            raise
        except Exception:
            super().handleError(record)
            # (better trigger the same exception again than continue
            # if the following condition is true)
            if record.name == self._error_logger_name:
                return

        try:
            if record.levelno >= _ERROR_LEVEL_NO and not record.stack_info:
                record.formatted_call_stack_items = traceback.format_stack()
            routing_key = self._rk_template.format(
                hostname=HOSTNAME,
                script_basename=SCRIPT_BASENAME,
                levelname=record.levelname,
                loggername=record.name)
            try:
                self._pusher.push(record, routing_key)
            except ValueError:
                if not self._closing:
                    raise
        except RecursionError:  # see: https://bugs.python.org/issue36272 XXX: is it really needed?
            raise
        except Exception:
            self.handleError(record)

    def close(self):
        # typically, this method is called at interpreter exit
        # (by logging.shutdown() which is always registered with
        # atexit.register() machinery)
        try:
            try:
                try:
                    super(AMQPHandler, self).close()
                finally:
                    self._closing = True
                    self._pusher.shutdown()
            except:
                dump_condensed_debug_msg(
                    'EXCEPTION DURING EXECUTION OF close() '
                    'OF THE AMQP LOGGING HANDLER!')
                raise
        except Exception as exc:
            self._error_fifo.put(exc)
        finally:
            # (to break any traceback-related reference cycle)
            self = None  # noqa

    def handleError(self, record):
        try:
            exc = sys.exc_info()[1]
            self._error_fifo.put(exc)
        except RecursionError:  # see: https://bugs.python.org/issue36272 XXX: is it really needed?
            raise
        except Exception:
            super().handleError(record)
        else:
            super().handleError(record)
        finally:
            # (to break any traceback-related reference cycle)
            exc = self = None  # noqa
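
For reference, the per-logger rate limiting implemented by `_should_publish()` (inside `_make_record_serializer()`) can be distilled into a standalone sketch. The names below are illustrative only, not part of the n6 API, and the `msg_skipped_to_count`/`msg_reached_count_max` bookkeeping done by the real code is omitted:

import collections
import time

def make_rate_limiter(window=300, count_max=100):
    # per-key state: [current window id, per-msg counter]; each counter
    # starts at -count_max, exactly as in _should_publish() above
    state = collections.defaultdict(
        lambda: [None, collections.defaultdict(lambda: -count_max)])

    def should_publish(key, msg, now=None):
        cur_window = (time.time() if now is None else now) // window
        win_and_counts = state[key]
        if win_and_counts[0] != cur_window:
            # a new time window for this key => reset all its counters
            win_and_counts[0] = cur_window
            win_and_counts[1].clear()
        win_and_counts[1][msg] += 1
        return win_and_counts[1][msg] <= 0

    return should_publish

With the default values, at most 100 calls with the same (key, msg) pair return True within any single 300-second window; the first call in a new window resets the counters for that key.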