예제 #1
0
 def test_send_sync_rpc_response(self):
     """
     The response generator yields queued responses first and falls back to
     a heartbeat once the queue has nothing left to give.
     """
     client = self._sync_rpc_client
     queued = SyncRPCResponse(reqId=123, respBody=self._expected_resp)
     client._response_queue.put(queued)
     responses = client.send_sync_rpc_response()

     # The response we queued comes out first, unchanged.
     self.assertEqual(queued, next(responses))
     # The queue is now drained, so the generator times out and heartbeats.
     self.assertEqual(SyncRPCResponse(heartBeat=True), next(responses))
예제 #2
0
    async def _await_gateway_response(self, client, stream_id, body, req_id,
                                      response_queue, conn_closed_table):
        """
        Send the request body on stream_id and relay the response back as
        SyncRPCResponse messages on response_queue.

        Every chunk returned by self._read_stream becomes its own
        GatewayResponse carrying the response status and headers. The loop
        always holds one chunk of lookahead so it can tell which chunk is the
        last one. Only for that last chunk are the trailers fetched from the
        server and merged into the headers via self._get_resp_headers.
        (Presumably _read_stream returns a falsy value once the stream is
        exhausted — verify against _read_stream.)

        Args:
            client: connected client the request was started on
            stream_id: id of the stream returned by start_request
            body: request payload, sent with end_stream=True
            req_id: request id that's associated with the response
            response_queue: queue that SyncRPCResponses are put in
            conn_closed_table: table that maps req ids to whether the
                connection was closed; forwarded to _read_stream

        Returns: None
        """
        async def read_next():
            return await self._read_stream(client, stream_id, req_id,
                                           response_queue, conn_closed_table)

        await client.send_data(stream_id, body, end_stream=True)
        response_headers = await client.recv_response(stream_id)
        status = self._get_resp_status(response_headers)

        pending = await read_next()
        upcoming = await read_next()
        while True:
            is_last = not upcoming
            # Trailers only arrive after the final DATA frame.
            trailers = await client.recv_trailers(stream_id) if is_last else []
            response_queue.put(
                SyncRPCResponse(
                    heartBeat=False,
                    reqId=req_id,
                    respBody=GatewayResponse(
                        status=status,
                        headers=self._get_resp_headers(response_headers,
                                                       trailers),
                        payload=pending)))
            if is_last:
                return
            pending, upcoming = upcoming, await read_next()
예제 #3
0
    def test_send_request_done(self):
        """
        Test that send_request_done enqueues exactly one correctly-built
        SyncRPCResponse per call, both when the future holds a result and
        when it holds an error.
        Returns: None

        """
        response_queue = self._sync_rpc_client._response_queue
        self.assertTrue(response_queue.empty())

        # Success: future has a result, so send_request_done should enqueue a
        # SyncRPCResponse with the req_id and the expected_resp set in
        # MockFuture.
        req_id = 356
        future = MockFuture(is_error=False,
                            expected_result=self._expected_resp,
                            expected_err_msg="")
        self._sync_rpc_client.send_request_done(req_id, future)
        self.assertFalse(response_queue.empty())
        res = response_queue.get(block=False)
        self.assertEqual(
            res,
            SyncRPCResponse(reqId=req_id,
                            respBody=self._expected_resp,
                            heartBeat=False))
        self.assertTrue(response_queue.empty())

        # Failure: future holds an error, so the enqueued GatewayResponse
        # must carry the error message.
        req_id = 234
        expected_err_resp = GatewayResponse(err=self._expected_err_msg)
        future = MockFuture(is_error=True,
                            expected_result=GatewayResponse(),
                            expected_err_msg=self._expected_err_msg)
        self._sync_rpc_client.send_request_done(req_id, future)
        self.assertFalse(response_queue.empty())
        res = response_queue.get(block=False)
        self.assertEqual(
            res,
            SyncRPCResponse(reqId=req_id,
                            respBody=expected_err_resp,
                            heartBeat=False))
        # Exactly one response per call: nothing may be left behind.
        self.assertTrue(response_queue.empty())
예제 #4
0
    def send_request_done(self, req_id, future):
        """
        Done-callback for a future that produces a GatewayResponse.

        If the future raised an exception, or was cancelled, log the error
        and enqueue a SyncRPCResponse whose GatewayResponse carries the error
        message in its ``err`` field. Otherwise, enqueue a SyncRPCResponse
        wrapping the GatewayResponse returned by the future.

        Args:
            req_id: request id that's associated with the response
            future: A done future whose result is a GatewayResponse.

        Returns: None

        """
        # Local imports: only needed to recognise a cancelled future. Both
        # are listed because asyncio and concurrent.futures use distinct
        # CancelledError classes on Python 3.8+.
        import asyncio
        import concurrent.futures

        try:
            err = future.exception()
        except (asyncio.CancelledError, concurrent.futures.CancelledError):
            # exception() raises rather than returns for a cancelled future.
            # Without this, the callback would abort before enqueueing
            # anything, and no response would ever be sent for req_id.
            err = "request {} was cancelled".format(req_id)

        if err:
            logging.error("[SyncRPC] Forward to control proxy error: %s", err)
            self._response_queue.put(
                SyncRPCResponse(heartBeat=False,
                                reqId=req_id,
                                respBody=GatewayResponse(err=str(err))))
            return

        res = future.result()
        self._response_queue.put(
            SyncRPCResponse(heartBeat=False, reqId=req_id, respBody=res))
예제 #5
0
 async def try_read_stream():
     """
     Wait for the next chunk on stream_id, polling in 10-second windows.

     Each window that expires without data enqueues a SyncRPCResponse whose
     GatewayResponse has keepConnActive=True for req_id, then waits again.
     Raises ConnectionAbortedError as soon as conn_closed_table marks req_id
     as closed. The check runs after a timeout and also after a successful
     read.

     NOTE(review): asyncio.wait_for cancels the in-flight read_stream on
     timeout. This assumes client.read_stream can be cancelled and retried
     without losing data — confirm for the client library in use.
     """
     def raise_if_closed():
         if conn_closed_table.get(req_id, False):
             raise ConnectionAbortedError

     while True:
         try:
             chunk = await asyncio.wait_for(
                 client.read_stream(stream_id), timeout=10.0)
         except asyncio.TimeoutError:
             raise_if_closed()
             response_queue.put(
                 SyncRPCResponse(
                     heartBeat=False,
                     reqId=req_id,
                     respBody=GatewayResponse(keepConnActive=True)))
             continue
         raise_if_closed()
         return chunk
예제 #6
0
 def send_sync_rpc_response(self):
     """
     Endless generator feeding the SyncRPC response stream.

     Each step blocks for up to self._response_timeout seconds waiting for
     the next SyncRPCResponse on self._response_queue and yields it. If the
     wait times out, a heartbeat SyncRPCResponse(heartBeat=True) is yielded
     instead. The generator has no notion of a first call, so heartbeats
     simply recur whenever the queue stays empty for a full timeout.

     Returns: A generator of SyncRPCResponse
     """
     while True:
         try:
             outgoing = self._response_queue.get(
                 block=True, timeout=self._response_timeout)
         except queue.Empty:
             logging.debug("[SyncRPC] Sending heartbeat")
             outgoing = SyncRPCResponse(heartBeat=True)
         yield outgoing
예제 #7
0
    async def send(self, gateway_request, req_id, sync_rpc_response_queue,
                   conn_closed_table):
        """
        Forward a GatewayRequest to the service named by its authority and
        relay the answer back through sync_rpc_response_queue.

        A client for gateway_request.authority is obtained from
        self._get_client, and the response is enqueued by
        self._await_gateway_response. An error SyncRPCResponse (whose
        GatewayResponse carries the message in ``err``) is enqueued instead in
        two cases: req_id is already being handled, or any exception other
        than ConnectionAbortedError is raised while forwarding. A
        ConnectionAbortedError is only logged, as the connection being
        terminated by the cloud. The client connection is always closed
        before returning.

        Args:
            gateway_request: A GatewayRequest defined in
                sync_rpc_service.proto, with fields gwId, authority, path,
                headers, and payload.
            req_id: request id that's associated with the response
            sync_rpc_response_queue: the queue that responses are put in
            conn_closed_table: table that maps req ids to whether the
                connection was closed

        Returns: None.

        """
        def enqueue_error(message):
            # Report a failure for req_id back through the response queue.
            sync_rpc_response_queue.put(
                SyncRPCResponse(heartBeat=False,
                                reqId=req_id,
                                respBody=GatewayResponse(err=message)))

        client = await self._get_client(gateway_request.authority)

        # HACK: aioh2 has no handler for h2's PingReceived event, so each
        # ping would spam the log with a KeyError. Install a no-op handler
        # when the installed h2 defines the event (some older h2 versions do
        # not). Remove once aioh2 handles it itself.
        # pylint: disable=protected-access
        ping_received = getattr(h2.events, "PingReceived", None)
        if ping_received is not None:
            client._event_handlers[ping_received] = lambda _: None
        # pylint: enable=protected-access

        if req_id in self._connection_table:
            logging.error(
                "[SyncRPC] proxy_client is already handling "
                "request ID %s", req_id)
            enqueue_error(
                "request ID {} is already being handled".format(req_id))
            client.close_connection()
            return

        self._connection_table[req_id] = client
        try:
            await client.wait_functional()
            stream_id = await client.start_request(
                self._get_req_headers(gateway_request.headers,
                                      gateway_request.path,
                                      gateway_request.authority))
            await self._await_gateway_response(
                client, stream_id, gateway_request.payload, req_id,
                sync_rpc_response_queue, conn_closed_table)
        except ConnectionAbortedError:
            logging.error("[SyncRPC] proxy_client connection "
                          "terminated by cloud")
        except Exception as e:  # pylint: disable=broad-except
            logging.error("[SyncRPC] Exception in proxy_client: %s", e)
            enqueue_error(str(e))
        finally:
            del self._connection_table[req_id]
            client.close_connection()