async def _await_gateway_response(self, client, stream_id, body, req_id,
                                  response_queue, conn_closed_table):
    """
    Send the request body and stream the reply back through response_queue.

    One SyncRPCResponse is enqueued per chunk returned by _read_stream. The
    method reads one chunk ahead so it can tell which chunk is the last one
    (the chunk followed by a falsy read). Only that last chunk is paired with
    the stream's trailers; every earlier chunk is sent with empty trailers.
    The first chunk is always enqueued, even if it is empty.

    Args:
        client: the HTTP/2 client the stream belongs to
        stream_id: id of the already-started request stream
        body: request payload, sent with end_stream=True
        req_id: request id that's associated with the response
        response_queue: queue the SyncRPCResponse objects are put in
        conn_closed_table: table that maps req ids to if the conn is closed

    Returns:
        None
    """
    await client.send_data(stream_id, body, end_stream=True)
    resp_headers = await client.recv_response(stream_id)
    status = self._get_resp_status(resp_headers)

    read_args = (client, stream_id, req_id, response_queue, conn_closed_table)
    pending = await self._read_stream(*read_args)
    upcoming = await self._read_stream(*read_args)

    is_last = False
    while not is_last:
        # A falsy lookahead means `pending` is the final chunk.
        is_last = not upcoming
        if is_last:
            trailers = await client.recv_trailers(stream_id)
        else:
            trailers = []
        headers = self._get_resp_headers(resp_headers, trailers)
        chunk_resp = GatewayResponse(status=status, headers=headers,
                                     payload=pending)
        response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id,
                            respBody=chunk_resp))
        if not is_last:
            pending = upcoming
            upcoming = await self._read_stream(*read_args)
def setUp(self):
    """
    Build the per-test fixtures.

    Installs a fresh asyncio event loop, creates a SyncRPCClient on it with
    request 12345 marked as not closed, registers a 'test' service and proxy
    config with ServiceRegistry, and prepares the request, expected response,
    and expected error message used by the tests.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    self._loop = event_loop

    self._sync_rpc_client = SyncRPCClient(loop=event_loop,
                                          response_timeout=3)
    self._sync_rpc_client._conn_closed_table = {12345: False}

    ServiceRegistry.add_service('test', '0.0.0.0', 0)
    proxy_config = {
        'local_port': 2345,
        'cloud_address': 'test',
        'proxy_cloud_connections': True,
    }
    ServiceRegistry._PROXY_CONFIG = proxy_config

    request_headers = {
        'te': 'trailers',
        'content-type': 'application/grpc',
        'user-agent': 'grpc-python/1.4.0',
        'grpc-accept-encoding': 'identity',
    }
    self._req_body = GatewayRequest(
        gwId="test id",
        authority='mobility',
        path='/magma.MobilityService/ListAddedIPv4Blocks',
        headers=request_headers,
        payload=bytes.fromhex('0000000000'),
    )

    self._expected_resp = GatewayResponse(
        status="400",
        headers={"test_key": "test_val"},
        payload=b'\x00\x00\x00\x00\n\n\x08',
    )
    self._expected_err_msg = "test error"
async def _await_gateway_response(self, client, stream_id, body):
    """
    Send the request body and collect the complete reply as one response.

    Args:
        client: the HTTP/2 client the stream belongs to
        stream_id: id of the already-started request stream
        body: request payload, sent with end_stream=True

    Returns:
        A GatewayResponse built from the response status, the merged
        response headers and trailers, and the payload read with
        read_stream(stream_id, -1).
    """
    await client.send_data(stream_id, body, end_stream=True)
    resp_headers = await client.recv_response(stream_id)
    status = self._get_resp_status(resp_headers)
    payload = await client.read_stream(stream_id, -1)
    trailers = await client.recv_trailers(stream_id)
    return GatewayResponse(
        status=status,
        headers=self._get_resp_headers(resp_headers, trailers),
        payload=payload,
    )
def test_send_request_done(self):
    """
    Check that send_request_done enqueues the right SyncRPCResponse.

    Covers two cases: a future that completed with a result, and a future
    that raised. The queue must hold exactly one matching response after
    each call.

    Returns:
        None
    """
    rpc_client = self._sync_rpc_client
    response_queue = rpc_client._response_queue
    self.assertTrue(response_queue.empty())

    # Successful future: the response carries the future's GatewayResponse.
    ok_req_id = 356
    ok_future = MockFuture(is_error=False,
                           expected_result=self._expected_resp,
                           expected_err_msg="")
    rpc_client.send_request_done(ok_req_id, ok_future)
    self.assertFalse(response_queue.empty())
    ok_res = response_queue.get(block=False)
    self.assertEqual(
        ok_res,
        SyncRPCResponse(reqId=ok_req_id, respBody=self._expected_resp,
                        heartBeat=False))
    self.assertTrue(response_queue.empty())

    # Failed future: the response carries a GatewayResponse with only err set.
    err_req_id = 234
    err_future = MockFuture(is_error=True,
                            expected_result=GatewayResponse(),
                            expected_err_msg=self._expected_err_msg)
    rpc_client.send_request_done(err_req_id, err_future)
    err_res = response_queue.get(block=False)
    self.assertEqual(
        err_res,
        SyncRPCResponse(reqId=err_req_id,
                        respBody=GatewayResponse(err=self._expected_err_msg),
                        heartBeat=False))
async def try_read_stream():
    """
    Return the next chunk from the stream and send keep-alives while waiting.

    Each read waits at most 10 seconds. On timeout, a keep-alive
    SyncRPCResponse (GatewayResponse(keepConnActive=True)) is put on
    response_queue and the read is retried.

    Relies on client, stream_id, req_id, response_queue and
    conn_closed_table from the enclosing scope, which is not shown here.

    Raises:
        ConnectionAbortedError: if conn_closed_table marks req_id as closed.
            This is checked after every successful read and every timeout.
    """
    while True:
        try:
            chunk = await asyncio.wait_for(client.read_stream(stream_id),
                                           timeout=10.0)
        except asyncio.TimeoutError:
            if conn_closed_table.get(req_id, False):
                raise ConnectionAbortedError
            keep_alive = GatewayResponse(keepConnActive=True)
            response_queue.put(
                SyncRPCResponse(heartBeat=False, reqId=req_id,
                                respBody=keep_alive))
        else:
            if conn_closed_table.get(req_id, False):
                raise ConnectionAbortedError
            return chunk
def send_request_done(self, req_id, future):
    """
    Enqueue the outcome of a finished request future.

    Behavior by outcome:
      - The future raised: log the error and enqueue a SyncRPCResponse whose
        GatewayResponse has only ``err`` set, to the stringified exception.
      - The future was cancelled: handle it the same way, using a
        cancellation message. Calling ``future.exception()`` on a cancelled
        future raises, so without this handling the callback would abort and
        nothing would be enqueued for req_id.
      - Otherwise: enqueue a SyncRPCResponse wrapping ``future.result()``.

    Args:
        req_id: request id that's associated with the response
        future: A future that is done and contains a GatewayResponse, or
            the exception that was raised. Presumably an asyncio or
            concurrent.futures future; both are handled. TODO confirm
            against the caller.

    Returns:
        None
    """
    # Local imports so this block needs no file-level dependency change.
    # asyncio.CancelledError and concurrent.futures.CancelledError are
    # distinct classes on Python 3.8+, so both are caught.
    import asyncio
    import concurrent.futures

    try:
        err = future.exception()
    except (asyncio.CancelledError, concurrent.futures.CancelledError):
        err = "request {} was cancelled".format(req_id)

    if err:
        logging.error("[SyncRPC] Forward to control proxy error: %s", err)
        self._response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id,
                            respBody=GatewayResponse(err=str(err))))
    else:
        res = future.result()
        self._response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id, respBody=res))
async def send(self, gateway_request, req_id, sync_rpc_response_queue,
               conn_closed_table):
    """
    Forward a GatewayRequest to the service named by its authority.

    The response queue and conn_closed_table are passed to
    _await_gateway_response, which is expected to enqueue the
    SyncRPCResponse(s) carrying the GatewayResponse.

    Two other cases enqueue a SyncRPCResponse whose GatewayResponse has only
    ``err`` set:
      - req_id is already being handled; the new client connection is then
        closed.
      - Forwarding raises any exception other than ConnectionAbortedError.
    A ConnectionAbortedError is only logged, and nothing is enqueued for it.

    Args:
        gateway_request: A GatewayRequest defined in sync_rpc_service.proto,
            with fields gwId, authority, path, headers, and payload.
        req_id: request id that's associated with the response
        sync_rpc_response_queue: the response queue that responses will be
            put in
        conn_closed_table: table that maps req ids to whether the
            connection is closed

    Returns:
        None.
    """
    client = await self._get_client(gateway_request.authority)

    # Small hack: aioh2 has no handler for PingReceived, so the log gets
    # spammed with KeyError messages. Install a no-op handler. The hasattr
    # check is needed because some older versions of h2 may not have the
    # PingReceived event. Remove once aioh2 supports it.
    # pylint: disable=protected-access
    if hasattr(h2.events, "PingReceived"):
        client._event_handlers[h2.events.PingReceived] = lambda _: None
    # pylint: enable=protected-access

    if req_id in self._connection_table:
        # The request is already in flight: reject the duplicate.
        logging.error("[SyncRPC] proxy_client is already handling "
                      "request ID %s", req_id)
        duplicate_msg = "request ID {} is already being handled".format(
            req_id)
        sync_rpc_response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id,
                            respBody=GatewayResponse(err=duplicate_msg)))
        client.close_connection()
        return

    self._connection_table[req_id] = client
    try:
        await client.wait_functional()
        req_headers = self._get_req_headers(gateway_request.headers,
                                            gateway_request.path,
                                            gateway_request.authority)
        body = gateway_request.payload
        stream_id = await client.start_request(req_headers)
        await self._await_gateway_response(client, stream_id, body, req_id,
                                           sync_rpc_response_queue,
                                           conn_closed_table)
    except ConnectionAbortedError:
        logging.error("[SyncRPC] proxy_client connection "
                      "terminated by cloud")
    except Exception as exc:  # pylint: disable=broad-except
        logging.error("[SyncRPC] Exception in proxy_client: %s", exc)
        sync_rpc_response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id,
                            respBody=GatewayResponse(err=str(exc))))
    finally:
        # Always release the slot and the connection, whatever the outcome.
        del self._connection_table[req_id]
        client.close_connection()