def setUp(self, event_loop):
    """
    Prepare a message tracker holding four items spread over two distinct
    Concent request ids, and a response queue pool that deliberately knows
    only two of the four connection ids.
    """
    self.connection_id_1 = 7
    self.connection_id_2 = 8
    self.connection_id_3 = 13
    self.connection_id_4 = 14
    self.ss_request_id_1 = 17
    self.ss_request_id_2 = 19
    self.ss_request_id_3 = 20
    self.ss_request_id_4 = 21
    self.golem_message = Ping()
    self.golem_message_from_ss = Ping()
    self.concent_request_id = 777
    # BUG FIX: this was also 777 — identical to `concent_request_id` — which
    # made the "different" request id meaningless: any test distinguishing
    # tracker items by Concent request id would pass vacuously.
    self.different_concent_request_id = 888
    self.message_tracker = OrderedDict([
        (self.ss_request_id_1, MessageTrackerItem(self.concent_request_id, self.connection_id_1, self.golem_message, FROZEN_TIMESTAMP)),
        (self.ss_request_id_2, MessageTrackerItem(self.concent_request_id, self.connection_id_2, self.golem_message, FROZEN_TIMESTAMP)),
        (self.ss_request_id_3, MessageTrackerItem(self.different_concent_request_id, self.connection_id_3, self.golem_message, FROZEN_TIMESTAMP)),
        (self.ss_request_id_4, MessageTrackerItem(self.different_concent_request_id, self.connection_id_4, self.golem_message, FROZEN_TIMESTAMP)),
    ])
    # Only connections 1 and 4 get response queues — items for 2 and 3 have
    # no matching queue on purpose.
    self.response_queue_pool = QueuePool(
        {
            self.connection_id_1: Queue(loop=event_loop),
            self.connection_id_4: Queue(loop=event_loop),
        }
    )
def test_that_queue_pool_is_created_with_default_params(self, _mocked_get_logger, mocked_asyncio):
    # With no arguments, QueuePool should start empty and fall back to the
    # default event loop and the module logger (both mocked to sentinels here).
    mocked_asyncio.get_event_loop.return_value = sentinel.loop

    pool = QueuePool()

    assert_that(pool).is_equal_to({})
    assert_that(pool.loop).is_equal_to(sentinel.loop)
    assert_that(pool.logger).is_equal_to(sentinel.logger)
def test_that_queue_pool_is_created_with_given_params(self, event_loop):
    # Explicitly supplied initial data, loop and logger must be used verbatim.
    seed_data = {1: asyncio.Queue(loop=event_loop)}
    custom_logger = create_autospec(spec=Logger, spec_set=True)

    pool = QueuePool(seed_data, event_loop, custom_logger)

    assert_that(pool).is_equal_to(seed_data)
    assert_that(pool.loop).is_equal_to(event_loop)
    assert_that(pool.logger).is_equal_to(custom_logger)
def setUp(self, event_loop):
    # Two queues in the pool: the first receives three response items,
    # the second receives the remaining one.
    self.first_index = 1
    self.second_index = 2
    ping = Ping()
    response_queue_items = [
        ResponseQueueItem(ping, request_id, get_current_utc_timestamp())
        for request_id in (777, 888, 1001, 1007)
    ]
    first_queue = asyncio.Queue(loop=event_loop)
    second_queue = asyncio.Queue(loop=event_loop)
    self.number_of_items_in_first_queue = 3
    event_loop.run_until_complete(self._populate_queues(first_queue, response_queue_items, second_queue))
    self.initial_dict = {
        self.first_index: first_queue,
        self.second_index: second_queue,
    }
    self.logger_mock = create_autospec(spec=Logger, spec_set=True)
    self.queue_pool = QueuePool(self.initial_dict, event_loop, self.logger_mock)
def __init__(
    self,
    bind_address: Optional[str]=None,
    internal_port: Optional[int]=None,
    external_port: Optional[int]=None,
    loop: Optional[BaseEventLoop]=None
) -> None:
    """
    Initialize the MiddleMan relay.

    :param bind_address: address to bind both servers to; defaults to LOCALHOST_IP.
    :param internal_port: port for the Concent-facing server; defaults to DEFAULT_INTERNAL_PORT.
    :param external_port: port for the Signing-Service-facing server; defaults to DEFAULT_EXTERNAL_PORT.
    :param loop: event loop to run on; defaults to asyncio's current event loop.
    """
    self._bind_address = bind_address if bind_address is not None else LOCALHOST_IP
    self._internal_port = internal_port if internal_port is not None else DEFAULT_INTERNAL_PORT
    self._external_port = external_port if external_port is not None else DEFAULT_EXTERNAL_PORT
    # BUG FIX: these attributes hold the server objects returned by
    # asyncio.start_server() — not event loops — so the correct annotation
    # is Optional[asyncio.AbstractServer], not Optional[BaseEventLoop].
    self._server_for_concent: Optional[asyncio.AbstractServer] = None
    self._server_for_signing_service: Optional[asyncio.AbstractServer] = None
    self._is_signing_service_connection_active = False
    self._loop = loop if loop is not None else asyncio.get_event_loop()
    # Monotonically increasing (mod CONNECTION_COUNTER_LIMIT) id per Concent connection.
    self._connection_id = 0
    self._request_queue: asyncio.Queue = asyncio.Queue(loop=self._loop)
    self._response_queue_pool = QueuePool(loop=self._loop)
    self._message_tracker: OrderedDict = OrderedDict()
    # Connections awaiting authentication as the Signing Service.
    self._ss_connection_candidates: List[Tuple[asyncio.Task, asyncio.StreamWriter]] = []
    # Handle shutdown signal.
    self._loop.add_signal_handler(signal.SIGTERM, self._terminate_connections)
def setUp(self, event_loop):
    # Fixture for request-consumer tests: a signed Ping wrapped in a
    # RequestQueueItem, plus a pool with a queue for one known connection.
    self.mocked_writer = prepare_mocked_writer()
    self.message_tracker = OrderedDict()
    self.golem_message = Ping()
    sign_message(self.golem_message, CONCENT_PRIVATE_KEY)
    self.connection_id = 4
    self.request_id = 888
    self.queue = Queue(loop=event_loop)
    self.queue_pool = QueuePool(
        {self.connection_id: Queue(loop=event_loop)},
        loop=event_loop,
    )
    self.signing_service_request_id = 1
    self.request_queue_item = RequestQueueItem(
        self.connection_id,
        self.request_id,
        self.golem_message,
        FROZEN_TIMESTAMP,
    )
async def test_that_when_connection_id_no_longer_exists_corresponding_item_is_dropped(self, event_loop):
    # A queued request whose connection id has no queue in the pool must be
    # dropped (tracker stays empty), logged, and never written anywhere.
    settings_override = override_settings(
        CONCENT_PRIVATE_KEY=CONCENT_PRIVATE_KEY,
        CONCENT_PUBLIC_KEY=CONCENT_PUBLIC_KEY,
    )
    with patch("middleman.asynchronous_operations.logger") as mocked_logger, settings_override:
        await self.queue.put(self.request_queue_item)
        consumer_task = event_loop.create_task(
            request_consumer(self.queue, QueuePool({}), self.message_tracker, self.mocked_writer)
        )
        await self.queue.join()
        consumer_task.cancel()

        assert_that(self.message_tracker).is_empty()
        mocked_logger.info.assert_called_once_with(
            f"No matching queue for connection id: {self.request_queue_item.connection_id}"
        )
        self.mocked_writer.assert_not_called()
class TestQueuePoolOperations:
    """Tests of QueuePool mapping mutations and logging of unretrieved queue items."""

    @pytest.fixture(autouse=True)
    def setUp(self, event_loop):
        # Two queues in the pool: the first receives three response items,
        # the second receives the remaining one.
        self.first_index = 1
        self.second_index = 2
        ping = Ping()
        queued_items = [
            ResponseQueueItem(ping, request_id, get_current_utc_timestamp())
            for request_id in (777, 888, 1001, 1007)
        ]
        queue_one = asyncio.Queue(loop=event_loop)
        queue_two = asyncio.Queue(loop=event_loop)
        self.number_of_items_in_first_queue = 3
        event_loop.run_until_complete(self._populate_queues(queue_one, queued_items, queue_two))
        self.initial_dict = {
            self.first_index: queue_one,
            self.second_index: queue_two,
        }
        self.logger_mock = create_autospec(spec=Logger, spec_set=True)
        self.queue_pool = QueuePool(self.initial_dict, event_loop, self.logger_mock)

    def test_that_when_already_existing_connection_is_added_exception_is_thrown(self):
        # Re-assigning an already-registered connection id must be rejected.
        with pytest.raises(KeyError):
            self.queue_pool[1] = asyncio.Queue()

    def test_that_when_already_existing_connection_is_added_during_update_exception_is_thrown(self):
        # update() with keys that already exist must be rejected as well.
        with pytest.raises(KeyError):
            self.queue_pool.update(self.initial_dict)

    def test_that_deleting_mapping_with_non_empty_queue_logs_untretrived_queue_items(self, event_loop):
        async def scenario():
            with freeze_time("2018-09-01 11:48:04"):
                del self.queue_pool[self.first_index]
                await asyncio.sleep(0.0001)
                assert_that(self.queue_pool.keys()).does_not_contain(self.first_index)
                # One log entry per item that was still sitting in the queue.
                assert_that(self.logger_mock.info.call_count).is_equal_to(self.number_of_items_in_first_queue)
        event_loop.run_until_complete(scenario())

    def test_that_popping_mapping_with_non_empty_queue_logs_unretrieved_queue_items(self, event_loop):
        async def scenario():
            with freeze_time("2018-09-01 11:48:04"):
                popped_queue = self.queue_pool.pop(self.second_index)
                await asyncio.sleep(0.0001)
                # The popped queue is drained while its single item is logged.
                assert_that(popped_queue.empty()).is_true()
                assert_that(self.queue_pool.keys()).does_not_contain(self.second_index)
                assert_that(self.logger_mock.info.call_count).is_equal_to(1)
        event_loop.run_until_complete(scenario())

    def test_that_using_popitem_on_mapping_with_non_empty_queue_logs_unretrieved_queue_items(self, event_loop):
        async def scenario():
            with freeze_time("2018-09-01 11:48:04"):
                connection_id, popped_queue = self.queue_pool.popitem()
                await asyncio.sleep(0.0001)
                assert_that(popped_queue.empty()).is_true()
                assert_that(self.queue_pool.keys()).does_not_contain(connection_id)
                assert_that(self.logger_mock.info.call_count).is_equal_to(1)
        event_loop.run_until_complete(scenario())

    async def _populate_queues(self, first_queue, response_queue_items, second_queue):
        # The first queue gets the leading chunk of items; the second gets the rest.
        split_at = self.number_of_items_in_first_queue
        for item in response_queue_items[:split_at]:
            await first_queue.put(item)
        for item in response_queue_items[split_at:]:
            await second_queue.put(item)
class MiddleMan:
    """
    Relay server between Concent and the Signing Service.

    Runs two asyncio servers on a single event loop: an internal one accepting
    Concent connections and an external one accepting a single authenticated
    Signing Service connection.  Messages flow between them through a shared
    request queue, a per-connection response queue pool and a message tracker.
    """

    def __init__(
        self,
        bind_address: Optional[str]=None,
        internal_port: Optional[int]=None,
        external_port: Optional[int]=None,
        loop: Optional[BaseEventLoop]=None
    ) -> None:
        self._bind_address = bind_address if bind_address is not None else LOCALHOST_IP
        self._internal_port = internal_port if internal_port is not None else DEFAULT_INTERNAL_PORT
        self._external_port = external_port if external_port is not None else DEFAULT_EXTERNAL_PORT
        # BUG FIX: these attributes hold the server objects returned by
        # asyncio.start_server() (see _start_middleman) — not event loops —
        # so the annotation is Optional[asyncio.AbstractServer], not
        # Optional[BaseEventLoop].
        self._server_for_concent: Optional[asyncio.AbstractServer] = None
        self._server_for_signing_service: Optional[asyncio.AbstractServer] = None
        self._is_signing_service_connection_active = False
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        # Monotonically increasing (mod CONNECTION_COUNTER_LIMIT) id per Concent connection.
        self._connection_id = 0
        self._request_queue: asyncio.Queue = asyncio.Queue(loop=self._loop)
        self._response_queue_pool = QueuePool(loop=self._loop)
        self._message_tracker: OrderedDict = OrderedDict()
        # Connections currently undergoing authentication as the Signing Service.
        self._ss_connection_candidates: List[Tuple[asyncio.Task, asyncio.StreamWriter]] = []
        # Handle shutdown signal.
        self._loop.add_signal_handler(signal.SIGTERM, self._terminate_connections)

    def run(self) -> None:
        """ It is a wrapper layer over "main loop" which handles exceptions """
        try:
            self._run()
        except KeyboardInterrupt:
            # if CTRl-C is pressed before server starts, it will intercepted here (exception will not be reported to Sentry)
            logger.info("Ctrl-C has been pressed.")
            logger.info("Exiting.")
        except SystemExit:
            # system exit should be reraised (returned) to OS
            raise
        except Exception as exception:  # pylint: disable=broad-except
            # All other (unhandled) exceptions should be reported to Sentry via crash logger
            logger.exception(str(exception))
            crash_logger.error(
                f"Exception occurred: {exception}, Traceback: {traceback.format_exc()}"
            )

    def _run(self) -> None:
        """ Main functionality is implemented here - start of the server and waiting for and handling incoming connections """
        try:
            # start MiddleMan server
            logger.info("Starting MiddleMan")
            self._start_middleman()
        except OSError as exception:
            logger.error(
                f"Exception <OSError> occurred while starting MiddleMan server for Concent: {str(exception)}"
            )
            exit(ERROR_ADDRESS_ALREADY_IN_USE)
        try:
            # Serve requests until Ctrl+C is pressed
            logger.info(
                'MiddleMan is serving for Concent on {}'.format(
                    self._server_for_concent.sockets[0].getsockname()  # type: ignore
                )
            )
            logger.info(
                'MiddleMan is serving for Signing Service on {}'.format(
                    self._server_for_signing_service.sockets[0].getsockname()  # type: ignore
                )
            )
            self._run_forever()
        except KeyboardInterrupt:
            logger.info("Ctrl-C has been pressed.")
        # Close the server
        logger.info("Server is closing...")
        self._close_middleman()
        logger.info("Closed.")
        exit()

    def _run_forever(self) -> None:
        # Blocks until the loop is stopped (SIGTERM handler) or KeyboardInterrupt is raised.
        self._loop.run_forever()

    def _start_middleman(self) -> None:
        """Create and start both servers: Concent side (internal) and Signing Service side (external)."""
        concent_server_coroutine = asyncio.start_server(
            self._handle_concent_connection,
            self._bind_address,
            self._internal_port,
            loop=self._loop,
            limit=MAXIMUM_FRAME_LENGTH
        )
        self._server_for_concent = self._loop.run_until_complete(concent_server_coroutine)
        service_server_coroutine = asyncio.start_server(
            self._handle_service_connection,
            self._bind_address,
            self._external_port,
            loop=self._loop,
            limit=MAXIMUM_FRAME_LENGTH
        )
        self._server_for_signing_service = self._loop.run_until_complete(service_server_coroutine)

    def _close_middleman(self) -> None:
        """Close both servers, cancel all outstanding tasks and close the loop."""
        self._server_for_concent.close()  # type: ignore
        self._loop.run_until_complete(self._server_for_concent.wait_closed())  # type: ignore
        self._server_for_signing_service.close()  # type: ignore
        self._loop.run_until_complete(self._server_for_signing_service.wait_closed())  # type: ignore
        self._cancel_pending_tasks(asyncio.Task.all_tasks(), await_cancellation=True)
        self._loop.close()

    async def _handle_concent_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """
        Handle one Concent connection: assign it a connection id and a response
        queue, then run producer/consumer tasks until Concent stops sending.
        """
        tasks = []
        response_queue: asyncio.Queue = asyncio.Queue(loop=self._loop)
        connection_id = self._connection_id = (self._connection_id + 1) % CONNECTION_COUNTER_LIMIT
        self._response_queue_pool[connection_id] = response_queue
        try:
            request_producer_task = self._loop.create_task(
                request_producer(self._request_queue, response_queue, reader, connection_id)
            )
            response_consumer_task = self._loop.create_task(
                response_consumer(response_queue, writer, connection_id)
            )
            tasks.append(request_producer_task)
            tasks.append(response_consumer_task)
            await request_producer_task      # 1. wait until producer task finishes (Concent will sent no more messages)
            await asyncio.sleep(PROCESSING_TIMEOUT)  # 2. give some time to process items already put to request queue
            await response_queue.join()      # 3. wait until items from response queue are processed (sent back to Concent)
            response_consumer_task.cancel()
        except asyncio.CancelledError:
            # CancelledError shall not be treated as crash and logged in Sentry. It is only logged as info
            logger.debug(f"CancelledError in Concent connection handler. Connection ID: {connection_id}.")
            raise
        except Exception as exception:  # pylint: disable=broad-except
            crash_logger.error(
                f"Exception occurred: {exception}, Traceback: {traceback.format_exc()}"
            )
            raise
        finally:
            logger.debug(f"Canceling tasks upon exit of Concent connection handler. Number of tasks to cancel: {len(tasks)}.")
            # regardless of exception's occurrence, all unfinished tasks should be cancelled
            # if exceptions occurs, producer task might need cancelling as well
            self._cancel_pending_tasks(tasks)
            # remove response queue from the pool
            removed_queue: Optional[asyncio.Queue] = self._response_queue_pool.pop(connection_id, None)
            if removed_queue is None:
                logger.warning(f"Response queue for connection ID: {connection_id} has been already removed")
            else:
                logger.info(f"Removing response queue for connection ID: {connection_id}.")

    async def _handle_service_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """
        Handle a Signing Service connection: authenticate it, then run the
        consumer/producer/heartbeat tasks until the first one finishes.
        Only one active Signing Service connection is allowed at a time.
        """
        if self._is_signing_service_connection_active:
            writer.close()
        else:
            tasks: list = []
            try:
                successful = await self._authenticate_signing_service(reader, writer)
                if not successful:
                    writer.close()
                    return
                request_consumer_task = self._loop.create_task(
                    request_consumer(
                        self._request_queue,
                        self._response_queue_pool,
                        self._message_tracker,
                        writer
                    )
                )
                tasks.append(request_consumer_task)
                response_producer_task = self._loop.create_task(
                    response_producer(
                        self._response_queue_pool,
                        reader,
                        self._message_tracker
                    )
                )
                tasks.append(response_producer_task)
                heartbeat_producer_task = self._loop.create_task(
                    heartbeat_producer(
                        writer,
                    )
                )
                tasks.append(heartbeat_producer_task)
                # The first task to finish (or fail) ends the whole connection.
                done_tasks, pending_tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                for future in pending_tasks:
                    future.cancel()
                for future in done_tasks:
                    exception_from_task = future.exception()
                    if exception_from_task is not None:
                        raise exception_from_task
            except asyncio.CancelledError:
                # CancelledError shall not be treated as crash and logged in Sentry. It is only logged as info
                logger.debug("CancelledError in Signing Service connection handler.")
            except Exception as exception:  # pylint: disable=broad-except
                crash_logger.error(
                    f"Exception occurred in Signing Service connection handler: {exception}, Traceback: {traceback.format_exc()}"
                )
                raise
            finally:
                logger.debug(f"Canceling tasks upon exit of Signing Service connection handler. Number of tasks to cancel: {len(tasks)}.")
                # cancel all tasks - if task is already done/cancelled it makes no harm
                self._cancel_pending_tasks(tasks)
                self._is_signing_service_connection_active = False

    def _terminate_connections(self) -> None:
        # SIGTERM handler: stopping the loop makes _run_forever() return,
        # after which _run() performs the orderly shutdown.
        logger.info('SIGTERM received - closing connections and exiting.')
        self._loop.stop()

    def _cancel_pending_tasks(self, tasks: Iterable[asyncio.Task], await_cancellation: bool = False) -> None:
        """
        Cancel the given tasks.  When `await_cancellation` is set, run the loop
        until each task has actually processed its cancellation.
        """
        for task in tasks:
            logger.debug(f'Task will be canceled. Task: {task}')
            task.cancel()
            if await_cancellation:
                # Now we should await task to execute it's cancellation.
                # Cancelled task raises asyncio.CancelledError that we can suppress:
                with suppress(asyncio.CancelledError):
                    self._loop.run_until_complete(task)

    async def _authenticate_signing_service(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> bool:
        """
        Run the authentication handshake for a Signing Service candidate.
        Returns True when the candidate authenticated successfully; in that
        case all other pending candidates are aborted.
        """
        logger.info("Signing Service candidate has connected, authenticating...")
        authentication_task = self._loop.create_task(is_authenticated(reader, writer))
        candidate = (authentication_task, writer)
        self._ss_connection_candidates.append(candidate)
        await authentication_task
        # BUG FIX: the candidate used to be removed with pop(index), where the
        # index was captured *before* the await.  Other candidates can be
        # removed from the list while this coroutine is suspended, shifting
        # positions, so the stale index could pop the wrong entry or raise
        # IndexError.  Removing by identity is safe regardless of ordering.
        if candidate in self._ss_connection_candidates:
            self._ss_connection_candidates.remove(candidate)
        is_signing_service_authenticated = authentication_task.result()
        if is_signing_service_authenticated:
            logger.info("Authentication successful: Signing Service has connected.")
            self._is_signing_service_connection_active = True
            self._abort_ongoing_authentication()
        else:
            logger.info("Authentication unsuccessful, closing connection with candidate.")
        return is_signing_service_authenticated

    def _abort_ongoing_authentication(self) -> None:
        # Cancel every remaining candidate's authentication task and close its
        # connection — a Signing Service has already been accepted.
        counter = 0
        length = len(self._ss_connection_candidates)
        for task, writer in self._ss_connection_candidates:
            logger.info(f"Canceling task {counter}/{length}...")
            task.cancel()
            writer.close()
            logger.info("Canceled!")
            counter += 1
        self._ss_connection_candidates.clear()