def test_status():
    """An error status survives a serialize/deserialize round trip lazily."""
    req = DataRequest()
    req.docs.extend([Document()])
    req.add_exception(ValueError('intentional_error'))

    wire = DataRequestProto.SerializeToString(req)
    restored = DataRequestProto.FromString(wire)

    # the freshly parsed request is still compressed ...
    assert not restored.is_decompressed
    assert restored.status.code == jina_pb2.StatusProto.ERROR
    # ... and reading the status forced decompression
    assert restored.is_decompressed
async def task_wrapper():
    """Sleep a short random interval, then echo the first request back.

    The request is re-wrapped from its proto bytes when already decompressed,
    otherwise from its raw buffer.
    """
    import random

    await asyncio.sleep(1 / (random.randint(1, 3) * 10))
    first = requests[0]
    if first.is_decompressed:
        return DataRequest(request=first.proto.SerializePartialToString()), {}
    return DataRequest(request=first.buffer), {}
def _new_data_request(endpoint, target, parameters):
    """Build a DataRequest, filling header fields and parameters when given.

    :param endpoint: optional executor endpoint for the request header
    :param target: optional target-executor filter for the request header
    :param parameters: optional parameters dict attached to the request
    :return: the populated DataRequest
    """
    request = DataRequest()
    if endpoint:
        request.header.exec_endpoint = endpoint
    if target:
        request.header.target_executor = target
    if parameters:
        request.parameters = parameters
    return request
def test_pprint_routes(capfd):
    """pprint_routes renders status icons, table headers and stack traces."""
    specs = [
        (jina_pb2.StatusProto.ERROR, ['r1\nline1', 'r2\nline2']),
        (jina_pb2.StatusProto.ERROR_CHAINED, ['line1', 'line2']),
        (jina_pb2.StatusProto.SUCCESS, []),
    ]
    routes = []
    for code, stacks in specs:
        route = jina_pb2.RouteProto()
        route.status.code = code
        route.status.exception.stacks.extend(stacks)
        routes.append(route)

    request = DataRequest()
    request.routes.extend(routes)
    pprint_routes(request)

    out, _ = capfd.readouterr()
    expected_tokens = (
        '⚪',
        '🟢',
        'Executor',
        'Time',
        'Exception',
        'r1',
        'line1r2',
        'line2',
        'line1line2',
    )
    for token in expected_tokens:
        assert token in out
def _yield_data_request():
    """Build one DataRequest carrying a fresh request id and a single empty Document."""
    request = DataRequest()
    request.header.request_id = random_identity()
    docs = DocumentArray()
    docs.append(Document())
    request.data.docs = docs
    return request
def _get_sync_requests_iterator(num_requests):
    """Yield `num_requests` DataRequests, each with a fresh id and one empty Document.

    :param num_requests: how many requests to generate
    :yield: the generated DataRequests, one at a time
    """
    for _ in range(num_requests):
        request = DataRequest()
        request.header.request_id = random_identity()
        docs = DocumentArray()
        docs.append(Document())
        request.data.docs = docs
        yield request
def FromString(x: bytes):
    """Wrap raw wire bytes into a lazily-decoded DataRequest.

    # noqa: DAR101
    # noqa: DAR102
    # noqa: DAR201
    """
    request = DataRequest(x)
    return request
def FromString(x: bytes):
    """Parse a serialized DataRequestListProto into a list of DataRequests.

    # noqa: DAR101
    # noqa: DAR102
    # noqa: DAR201
    """
    proto_list = jina_pb2.DataRequestListProto()
    proto_list.ParseFromString(x)
    result = []
    for req_proto in proto_list.requests:
        result.append(DataRequest.from_proto(req_proto))
    return result
def FromString(x: bytes):
    """Wrap wire bytes into a DataRequest, accumulating the received-byte counter.

    The running total is kept in the JINA_GRPC_RECV_BYTES environment variable.

    # noqa: DAR101
    # noqa: DAR102
    # noqa: DAR201
    """
    received_so_far = int(os.environ.get('JINA_GRPC_RECV_BYTES', 0))
    os.environ['JINA_GRPC_RECV_BYTES'] = str(received_so_far + len(x))
    return DataRequest(x)
async def recv_message(self) -> 'DataRequest':
    """Receive messages in bytes from server and convert to `DataRequest`

    ..note::
        aiohttp allows only one task which can `receive` concurrently.
        we need to make sure we don't create multiple tasks with `recv_message`

    :yield: response objects received from server
    """
    async for msg in self.response_iter:
        payload = msg.data
        yield DataRequest(payload)
async def test_http_clientlet():
    """Round-trip a DataRequest through HTTPClientlet against a live Flow."""
    from jina.helper import random_port

    port = random_port()
    with Flow(port_expose=port, protocol='http').add():
        url = f'http://localhost:{port}/post'
        async with HTTPClientlet(url=url, logger=logger) as iolet:
            outbound = _new_data_request('/', None, {'a': 'b'})
            raw = await iolet.send_message(outbound)
            inbound = DataRequest(await raw.json())
            assert inbound.header.exec_endpoint == '/'
            assert inbound.parameters == {'a': 'b'}
def test_lazy_serialization():
    """A deserialized request stays compressed until its docs are accessed."""
    n_docs = 1000
    req = DataRequest()
    docs = req.docs
    docs.extend([Document(text='534534534er5yr5y645745675675675345')] * n_docs)
    req.data.docs = docs

    wire = DataRequestProto.SerializeToString(req)
    restored = DataRequestProto.FromString(wire)

    # parsing alone must not decompress the payload ...
    assert not restored.is_decompressed
    assert len(restored.docs) == n_docs
    assert restored.docs == req.docs
    # ... but touching .docs does
    assert restored.is_decompressed
async def bytes_sending_client():
    """Stream serialized requests over the 'bytes' websocket subprotocol and verify replies."""
    async with aiohttp.ClientSession() as session:
        ws_url = f'ws://localhost:{GATEWAY_PORT}/'
        async with session.ws_connect(ws_url, protocols=('bytes',)) as ws:
            for da in input_da_gen():
                await ws.send_bytes(bytes_requestify(da))
                raw = await ws.receive_bytes()
                assert isinstance(raw, bytes)
                reply = DataRequest(raw).to_dict()
                assert reply['header']['exec_endpoint'] == '/foo'
                assert len(reply['data']) == INPUT_DA_LEN
                for doc in reply['data']:
                    assert doc['text'] == f'{doc["id"]} is fooed!'
async def req_iter():
    """Turn incoming websocket messages (dict or bytes) into DataRequests.

    An empty dict or a single `True` byte from the client ends the stream.
    """
    async for incoming in manager.iter(websocket):
        if isinstance(incoming, dict):
            if incoming == {}:
                break
            # NOTE: Helps in converting camelCase to snake_case
            generator_kwargs = JinaEndpointRequestModel(**incoming).dict()
            generator_kwargs['data_type'] = DataInputType.DICT
            if incoming['data'] is not None and 'docs' in incoming['data']:
                generator_kwargs['data'] = generator_kwargs['data']['docs']
            # you can't do `yield from` inside an async function
            for generated in request_generator(**generator_kwargs):
                yield generated
        elif isinstance(incoming, bytes):
            if incoming == bytes(True):
                break
            yield DataRequest(incoming)
def test_as_pb_object(req):
    """A DataRequest built from None still materializes a protobuf body."""
    empty_request = DataRequest(request=None)
    assert empty_request.proto
def test_init(req):
    """DataRequest accepts None, a proto message, a dict and a JSON string."""
    for source in (None, req, MessageToDict(req), MessageToJson(req)):
        assert DataRequest(request=source)
async def send(self, websocket: WebSocket, data: DataRequest) -> None:
    """Forward `data` to the client using its negotiated subprotocol.

    :param websocket: the connection whose subprotocol decides the encoding
    :param data: the request to send, as JSON text or raw bytes
    """
    client = self.get_client(websocket)
    subprotocol = self.protocol_dict[client]
    if subprotocol == WebsocketSubProtocols.JSON:
        return await websocket.send_json(data.to_dict(), mode='text')
    if subprotocol == WebsocketSubProtocols.BYTES:
        return await websocket.send_bytes(data.to_bytes())
def test_init_fail():
    """Constructing a DataRequest from an unsupported type raises BadRequestType."""
    bad_payload = 5
    with pytest.raises(BadRequestType):
        DataRequest(request=bad_payload)
def bytes_requestify(da: DocumentArray, exec_endpoint='/foo'):
    """Serialize `da` into the wire bytes of a DataRequest for `exec_endpoint`.

    :param da: the documents to embed as serialized bytes
    :param exec_endpoint: endpoint written into the request header
    :return: the request serialized to bytes
    """
    request = DataRequest()
    request._pb_body.header.exec_endpoint = exec_endpoint
    request.data.docs_bytes = da.to_bytes()
    return request.to_bytes()
async def req_iter():
    """Yield a DataRequest per binary websocket frame until the end marker."""
    async for payload in websocket.iter_bytes():
        # a single `True` byte is the client's end-of-stream marker
        if payload == bytes(True):
            break
        yield DataRequest(payload)
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    **kwargs,
):
    """Stream requests to the HTTP gateway and yield the parsed responses.

    :param inputs: the callable
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param kwargs: kwargs for _get_task_name and _get_requests
    :yields: generator over results
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.inputs = inputs
    request_iterator = self._get_requests(**kwargs)

    async with AsyncExitStack() as stack:
        try:
            cm1 = (
                ProgressBar(total_length=self._inputs_length)
                if self.show_progress
                else nullcontext()
            )
            p_bar = stack.enter_context(cm1)

            proto = 'https' if self.args.https else 'http'
            url = f'{proto}://{self.args.host}:{self.args.port}/post'
            iolet = await stack.enter_async_context(
                HTTPClientlet(url=url, logger=self.logger)
            )

            def _request_handler(request: 'Request') -> 'asyncio.Future':
                """
                For HTTP Client, for each request in the iterator, we
                `send_message` using http POST request and add it to the list
                of tasks which is awaited and yielded.

                :param request: current request in the iterator
                :return: asyncio Task for sending message
                """
                return asyncio.ensure_future(iolet.send_message(request=request))

            def _result_handler(result):
                # responses need no post-processing here; pass through as-is
                return result

            streamer = RequestStreamer(
                self.args,
                request_handler=_request_handler,
                result_handler=_result_handler,
            )
            async for response in streamer.stream(request_iterator):
                r_status = response.status
                r_str = await response.json()
                if r_status == 404:
                    raise BadClient(f'no such endpoint {url}')
                # BUGFIX: was `r_status > 300`, which silently accepted HTTP 300
                # (Multiple Choices); every non-2xx status must raise.
                elif r_status < 200 or r_status >= 300:
                    raise ValueError(r_str)

                # pull the docs payload out of the JSON body and rebuild it as
                # a DocumentArray before wrapping the rest in a DataRequest
                da = None
                if 'data' in r_str and r_str['data'] is not None:
                    from docarray import DocumentArray

                    da = DocumentArray.from_dict(r_str['data'])
                    del r_str['data']

                resp = DataRequest(r_str)
                if da is not None:
                    resp.data.docs = da

                callback_exec(
                    response=resp,
                    on_error=on_error,
                    on_done=on_done,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
                if self.show_progress:
                    p_bar.update()
                yield resp
        except aiohttp.ClientError as e:
            self.logger.error(
                f'Error while fetching response from HTTP server {e!r}'
            )
def test_docs(req):
    """`.docs` exposes the request payload as a one-element DocumentArray."""
    request = DataRequest(request=req)
    retrieved = request.docs
    assert isinstance(retrieved, DocumentArray)
    assert len(retrieved) == 1
def test_copy(req):
    """deepcopy yields an equal DataRequest of the same concrete type."""
    original = DataRequest(req)
    clone = copy.deepcopy(original)
    assert type(original) == type(clone)
    assert original == clone
    assert len(original.docs) == len(clone.docs)
def test_as_response(req):
    """`.response` views the same protobuf body as a Response subtype."""
    request = DataRequest(request=req)
    as_response = request.response
    assert isinstance(as_response, Response)
    assert isinstance(as_response, DataRequest)
    assert as_response._pb_body == request._pb_body
def test_access_header(req):
    """The wrapped request exposes the original proto header unchanged."""
    wrapped = DataRequest(request=req)
    assert wrapped.header == req.header
def test_data_backwards_compatibility(req):
    """`.data.docs` and the legacy `.docs` accessor agree on the payload."""
    wrapped = DataRequest(request=req)
    assert len(wrapped.data.docs) == 1
    assert len(wrapped.data.docs) == len(wrapped.docs)
def create_req_from_text(text: str):
    """Wrap `text` into a one-document DataRequest.

    The document carries a fixed tag `{'key': 4}`.

    :param text: the text payload of the single document
    :return: the populated DataRequest
    """
    request = DataRequest()
    docs = DocumentArray()
    docs.append(Document(text=text, tags={'key': 4}))
    request.data.docs = docs
    return request
def test_as_json_str(req):
    """`.json()` returns a string for both populated and empty requests."""
    for source in (req, None):
        assert isinstance(DataRequest(request=source).json(), str)