def test_grpc_ssl_with_flow_and_client(cert_pem, key_pem, error_log_level):
    """Send a ``STATUS`` control request over TLS to a Flow started with SSL files.

    The certificate file handed to the Flow is read back as bytes and passed
    to the client as its root certificates.
    """
    flow_kwargs = dict(protocol='grpc', ssl_certfile=cert_pem, ssl_keyfile=key_pem)
    with Flow(**flow_kwargs) as flow:
        with open(cert_pem, 'rb') as cert_file:
            root_certs = cert_file.read()

        status_request = ControlRequest('STATUS')
        GrpcConnectionPool.send_request_sync(
            request=status_request,
            target=f'localhost:{flow.port}',
            root_certificates=root_certs,
            tls=True,
            timeout=1.0,
        )
async def _dry_run(self, **kwargs) -> bool: """Sends a dry run to the Flow to validate if the Flow is ready to receive requests :param kwargs: potential kwargs received passed from the public interface :return: boolean indicating the health/readiness of the Flow """ try: async with GrpcConnectionPool.get_grpc_channel( f'{self.args.host}:{self.args.port}', asyncio=True, tls=self.args.tls, ) as channel: stub = jina_pb2_grpc.JinaGatewayDryRunRPCStub(channel) self.logger.debug(f'connected to {self.args.host}:{self.args.port}') call_result = stub.dry_run( jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(), **kwargs ) metadata, response = ( await call_result.trailing_metadata(), await call_result, ) if response.code == jina_pb2.StatusProto.SUCCESS: return True except Exception as e: self.logger.error(f'Error while getting response from grpc server {e!r}') return False
async def test_secure_send_request(private_key_cert_chain):
    """Exchange control requests with a TLS-secured gRPC server in a subprocess.

    A dummy control servicer that always answers ``DEACTIVATE`` is started in a
    separate process on a secure port. A ``STATUS`` request is then sent both
    synchronously and asynchronously and the returned command is checked.

    :param private_key_cert_chain: ``(private_key, certificate_chain)`` pair used
        as server credentials; the chain is also used as client root certificates
    """
    server1_ready_event = multiprocessing.Event()
    # Only the certificate chain is needed on the client side.
    _, certificate_chain = private_key_cert_chain

    def listen(port, event: multiprocessing.Event):
        class DummyServer:
            async def process_control(self, request, *args):
                return ControlRequest(command='DEACTIVATE')

        async def start_grpc_server():
            grpc_server = grpc.aio.server(
                options=[
                    # 'grpc.max_send_message_length' is the channel argument gRPC
                    # recognises ('grpc.max_send_request_length' is not one).
                    ('grpc.max_send_message_length', -1),
                    ('grpc.max_receive_message_length', -1),
                ]
            )
            jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(
                DummyServer(), grpc_server
            )
            grpc_server.add_secure_port(
                f'localhost:{port}',
                grpc.ssl_server_credentials((private_key_cert_chain,)),
            )
            await grpc_server.start()
            # Signal readiness only once the server has actually started.
            event.set()
            await grpc_server.wait_for_termination()

        asyncio.run(start_grpc_server())

    port = random_port()
    server_process1 = Process(
        target=listen,
        args=(
            port,
            server1_ready_event,
        ),
    )
    server_process1.start()

    try:
        time.sleep(0.1)
        server1_ready_event.wait()

        sent_msg = ControlRequest(command='STATUS')

        result = GrpcConnectionPool.send_request_sync(
            sent_msg, f'localhost:{port}', https=True, root_certificates=certificate_chain
        )
        assert result.command == 'DEACTIVATE'

        result = await GrpcConnectionPool.send_request_async(
            sent_msg, f'localhost:{port}', https=True, root_certificates=certificate_chain
        )
        assert result.command == 'DEACTIVATE'
    finally:
        # Always stop the server process, even when an assertion above fails,
        # so a failing test does not leave it running.
        server_process1.kill()
        server_process1.join()
def is_ready(ctrl_address: str, **kwargs) -> bool:
    """
    Probe a runtime by sending it a ``STATUS`` control request.

    :param ctrl_address: the address where the control request needs to be sent
    :param kwargs: extra keyword arguments (not used by this implementation)
    :return: ``False`` if sending the request raises ``RpcError``, ``True`` otherwise.
    """
    try:
        GrpcConnectionPool.send_request_sync(ControlRequest('STATUS'), ctrl_address)
    except RpcError:
        return False
    else:
        return True
def activate(self):
    """
    Register every worker pod of this deployment with the head pod.

    Nothing happens when the deployment has no head pod.
    """
    if self.head_pod is None:
        return

    for shard_id in self.pod_args['pods']:
        shard_pod_args = self.pod_args['pods'][shard_id]
        for pod_idx, worker_args in enumerate(shard_pod_args):
            worker_pod = self.shards[shard_id]._pods[pod_idx]
            worker_host = self.get_worker_host(worker_args, worker_pod, self.head_pod)
            GrpcConnectionPool.activate_worker_sync(
                worker_host,
                int(worker_args.port_in),
                self.head_pod.runtime_ctrl_address,
                shard_id,
            )
class GatewayRuntime(AsyncNewLoopRuntime, ABC):
    """
    Base Runtime that the GatewayRuntimes have to inherit from
    """

    def __init__(
        self,
        args: argparse.Namespace,
        cancel_event: Optional[
            Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
        ] = None,
        **kwargs,
    ):
        # Ordering matters: super().__init__() ends up calling
        # _set_topology_graph(), which needs self.timeout_send to exist already.
        raw_timeout = args.timeout_send
        # The argument is given in milliseconds and stored in seconds; a falsy
        # value is stored as-is.
        self.timeout_send = raw_timeout / 1e3 if raw_timeout else raw_timeout
        super().__init__(args, cancel_event, **kwargs)

    def _set_topology_graph(self):
        # check if it should be in K8s, maybe ConnectionPoolFactory to be created
        import json

        # All three arguments arrive as JSON strings and are decoded here.
        description = json.loads(self.args.graph_description)
        conditions = json.loads(self.args.graph_conditions)
        disable_reduce = json.loads(self.args.deployments_disable_reduce)
        self._topology_graph = TopologyGraph(
            description,
            conditions,
            disable_reduce,
            timeout_send=self.timeout_send,
        )

    def _set_connection_pool(self):
        import json

        addresses_by_deployment = json.loads(self.args.deployments_addresses)
        pool = GrpcConnectionPool(
            logger=self.logger,
            compression=self.args.compression,
            metrics_registry=self.metrics_registry,
        )
        self._connection_pool = pool
        # Register one head connection per configured address.
        for deployment_name, addresses in addresses_by_deployment.items():
            for address in addresses:
                pool.add_connection(
                    deployment=deployment_name, address=address, head=True
                )
def test_dynamic_polling_with_config(polling):
    """Check per-endpoint polling of a two-shard Deployment configured via ``--polling``."""
    endpoint_polling = {
        '/any': PollingType.ANY,
        '/all': PollingType.ALL,
        '*': polling,
    }
    cli_args = [
        '--uses',
        'DynamicPollingExecutor',
        '--shards',
        str(2),
        '--polling',
        json.dumps(endpoint_polling),
    ]
    pod = Deployment(set_deployment_parser().parse_args(cli_args))

    def _doc_count(endpoint):
        # Send one request to the head for `endpoint` and count returned docs.
        reply = GrpcConnectionPool.send_request_sync(
            _create_test_data_message(endpoint=endpoint),
            f'{pod.head_args.host}:{pod.head_args.port_in}',
            endpoint=endpoint,
        )
        return len(reply.docs)

    with pod:
        # 1 source doc + 2 docs, one added by each shard
        assert _doc_count('/all') == 1 + 2
        # 1 source doc + 1 doc added by the one shard
        assert _doc_count('/any') == 1 + 1
        # '/no_polling' falls back to the '*' entry: one shard for 'any',
        # otherwise both shards
        expected_docs = 1 + 1 if polling == 'any' else 1 + 2
        assert _doc_count('/no_polling') == expected_docs
async def _send_requests(pod):
    """Send three test data requests to the pod's head and collect the reply texts.

    :param pod: object exposing ``head_args.host`` and ``head_args.port_in``
    :return: set of all texts found in the documents of the three responses
    """

    def _texts_of_one_request():
        reply = GrpcConnectionPool.send_request_sync(
            _create_test_data_message(),
            f'{pod.head_args.host}:{pod.head_args.port_in}',
        )
        return reply.response.docs.texts

    collected_texts = set()
    for _attempt in range(3):
        collected_texts.update(_texts_of_one_request())
    return collected_texts
async def test_pods_trivial_topology(head_runtime_docker_image_built,
                                     worker_runtime_docker_image_built):
    """Run gateway -> head -> worker pods and post requests through the gateway."""
    worker_port = random_port()
    head_port = random_port()
    port = random_port()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    # one worker pod, one head pod and one gateway pod
    worker_pod = _create_worker_pod(worker_port)
    head_pod = _create_head_pod(head_port)
    gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port)

    with gateway_pod, head_pod, worker_pod:
        await asyncio.sleep(1.0)

        assert HeadRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=head_pod.runtime_ctrl_address,
            ready_or_shutdown_event=head_pod.ready_or_shutdown.event,
        )
        assert WorkerRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=worker_pod.runtime_ctrl_address,
            ready_or_shutdown_event=worker_pod.ready_or_shutdown.event,
        )

        for started_pod in (head_pod, worker_pod, gateway_pod):
            started_pod.ready_or_shutdown.event.wait(timeout=5.0)

        # this would be done by the Pod: it adds the worker to the head
        ctrl_host, ctrl_port = worker_pod.runtime_ctrl_address.split(':')
        activate_msg = ControlRequest(command='ACTIVATE')
        activate_msg.add_related_entity('worker', ctrl_host, int(ctrl_port))
        assert GrpcConnectionPool.send_request_sync(
            activate_msg, head_pod.runtime_ctrl_address
        )

        # send requests to the gateway
        client = Client(host='localhost', port=port, asyncio=True)
        responses = client.post(
            '/', inputs=async_inputs, request_size=1, return_responses=True
        )
        response_list = [response async for response in responses]

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1
def test_dynamic_polling(polling):
    """Send one request per endpoint to a two-shard runtime and check the queue length."""
    polling_config = {
        '/any': PollingType.ANY,
        '/all': PollingType.ALL,
        '*': polling,
    }
    args = set_pod_parser().parse_args(
        ['--polling', json.dumps(polling_config), '--shards', str(2)]
    )
    args.connection_list = json.dumps({0: ['fake_ip:8080'], 1: ['fake_ip:8080']})

    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    # (endpoint name, queue length expected once that request was handled)
    for endpoint_name, expected_queue_length in (('all', 2), ('any', 3)):
        with grpc.insecure_channel(
            f'{args.host}:{args.port}',
            options=GrpcConnectionPool.get_default_grpc_options(),
        ) as channel:
            stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
            response, call = stub.process_single_data.with_call(
                _create_test_data_message(endpoint=endpoint_name),
                metadata=(('endpoint', f'/{endpoint_name}'),),
            )

        assert response
        assert _queue_length(handle_queue) == expected_queue_length

    _destroy_runtime(args, cancel_event, runtime_thread)
async def task_wrapper(adress, messages_received):
    """Send one test data request to ``adress`` and record it once processed.

    :param adress: target handed to ``GrpcConnectionPool.create_async_channel_stub``
    :param messages_received: list the request is appended to after the call
        succeeded; its current length is passed to ``_create_test_data_message``
    """
    request = _create_test_data_message(len(messages_received))
    # Only the data stub and the channel of the returned 4-tuple are used.
    _, data_stub, _, channel = GrpcConnectionPool.create_async_channel_stub(adress)
    try:
        await data_stub.process_data(request)
    finally:
        # Close the channel even when the call fails so it is not leaked.
        await channel.close()
    messages_received.append(request)
async def test_pseudo_remote_pods_topologies(gateway, head, worker):
    """
    Combinations of local (l) / remote (r) gateway, head and worker:

    g(l)-h(l)-w(l) - works
    g(l)-h(l)-w(r) - works - head connects to worker via localhost
    g(l)-h(r)-w(r) - works - head (inside docker) connects to worker via dockerhost
    g(l)-h(r)-w(l) - doesn't work, a remote head needs a remote worker
    g(r)-...       - doesn't work, as the distributed parser is not enabled for gateway

    After any 1 failure, segfault
    """
    worker_port = random_port()
    head_port = random_port()
    port_expose = random_port()
    graph_description = (
        '{"start-gateway": ["deployment0"], "deployment0": ["end-gateway"]}'
    )
    # A remote head is addressed through HOST, a local one through 0.0.0.0.
    head_address_host = HOST if head == 'remote' else '0.0.0.0'
    deployments_addresses = (
        f'{{"deployment0": ["{head_address_host}:{head_port}"]}}'
    )

    # one head pod, one worker pod and one gateway pod
    head_pod = _create_head_pod(head, head_port)
    worker_pod = _create_worker_pod(worker, worker_port)
    gateway_pod = _create_gateway_pod(
        gateway, graph_description, deployments_addresses, port_expose
    )

    with gateway_pod, worker_pod, head_pod:
        await asyncio.sleep(1.0)

        # this would be done by the deployment: it adds the worker to the head
        ctrl_host, ctrl_port = worker_pod.runtime_ctrl_address.split(':')
        if head == 'remote':
            ctrl_host = __docker_host__
        activate_msg = ControlRequest(command='ACTIVATE')
        activate_msg.add_related_entity('worker', ctrl_host, int(ctrl_port))
        assert GrpcConnectionPool.send_request_sync(
            activate_msg, head_pod.runtime_ctrl_address
        )

        # send requests to the gateway
        client = Client(host='127.0.0.1', port=port_expose, asyncio=True)
        responses = client.post(
            '/', inputs=async_inputs, request_size=1, return_results=True
        )
        response_list = [response async for response in responses]

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1
def test_custom_swagger(p):
    """Check that the FastAPI app built from gateway args ``p`` has docs and OpenAPI routes."""
    args = set_gateway_parser().parse_args(p)
    logger = JinaLogger('')
    connection_pool = GrpcConnectionPool(logger=logger)
    app = get_fastapi_app(args, TopologyGraph({}), connection_pool, logger)
    # The TestClient has to be used as a context manager so the shutdown event
    # is generated; otherwise the app can hang because it is not cleaned up.
    # see https://fastapi.tiangolo.com/advanced/testing-events/
    with TestClient(app):
        for expected_fragment in ('/docs', '/openapi.json'):
            assert any(expected_fragment in route.path for route in app.routes)
def check_health_pod(addr: str):
    """Check whether a pod is healthy by sending it a ``STATUS`` control request.

    The outcome is printed; on failure the ``grpc.RpcError`` is printed too and
    then re-raised.

    :param addr: the address on which the pod is serving ex : localhost:1234
    """
    import grpc

    from jina.serve.networking import GrpcConnectionPool
    from jina.types.request.control import ControlRequest

    try:
        GrpcConnectionPool.send_request_sync(
            request=ControlRequest('STATUS'),
            target=addr,
        )
    except grpc.RpcError as rpc_error:
        print('The pod is unhealthy')
        print(rpc_error)
        raise rpc_error
    else:
        print('The pod is healthy')
async def test_blocking_sync_exec():
    """Send many concurrent requests to a worker whose executor sleeps synchronously.

    Passes when every reply carries the executor's text and the whole batch
    finishes in under twice ``REQUEST_COUNT * SLEEP_TIME`` seconds.
    """
    SLEEP_TIME = 0.01
    REQUEST_COUNT = 100

    class BlockingExecutor(Executor):
        @requests
        def foo(self, docs: DocumentArray, **kwargs):
            time.sleep(SLEEP_TIME)
            for doc in docs:
                doc.text = 'BlockingExecutor'
            return docs

    args = set_pod_parser().parse_args(['--uses', 'BlockingExecutor'])
    cancel_event = multiprocessing.Event()
    target = f'{args.host}:{args.port}'

    def _run_worker(runtime_args, stop_event):
        with WorkerRuntime(runtime_args, cancel_event=stop_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=_run_worker,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=target,
        ready_or_shutdown_event=Event(),
    )

    start_time = time.time()
    send_tasks = [
        asyncio.create_task(
            GrpcConnectionPool.send_request_async(
                _create_test_data_message(),
                target=f'{args.host}:{args.port}',
                timeout=3.0,
            )
        )
        for _ in range(REQUEST_COUNT)
    ]
    results = await asyncio.gather(*send_tasks)
    elapsed = time.time() - start_time

    assert all(result.docs.texts == ['BlockingExecutor'] for result in results)
    assert elapsed < (REQUEST_COUNT * SLEEP_TIME) * 2.0

    cancel_event.set()
    runtime_process.join()
def test_control_message_processing():
    """Check that data requests only succeed while a worker is registered."""
    args = set_pod_parser().parse_args([])
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    def _send_data_request():
        return GrpcConnectionPool.send_request_sync(
            _create_test_data_message(), f'{args.host}:{args.port}'
        )

    # no connection registered yet
    reply = _send_data_request()
    assert reply.status.code == reply.status.ERROR

    # after adding a connection, sending should work
    _add_worker(args, 'ip1')
    assert _send_data_request()

    # after removing the connection again, sending does not work anymore
    _remove_worker(args, 'ip1')
    reply = _send_data_request()
    assert reply.status.code == reply.status.ERROR

    _destroy_runtime(args, cancel_event, runtime_thread)
async def test_worker_runtime_slow_async_exec(uses):
    """Send ten concurrent requests to a worker and check the completion order of replies."""
    args = set_pod_parser().parse_args(['--uses', uses])
    cancel_event = multiprocessing.Event()
    target = f'{args.host}:{args.port}'

    def _run_worker(runtime_args, stop_event):
        with WorkerRuntime(runtime_args, cancel_event=stop_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=_run_worker,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    results = []
    async with grpc.aio.insecure_channel(
        target,
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)

        async def _single_request():
            return await stub.process_single_data(_create_test_data_message())

        tasks = [asyncio.create_task(_single_request()) for _ in range(10)]
        # Replies are recorded in the order in which the tasks complete.
        for finished in asyncio.as_completed(tasks):
            reply = await finished
            results.append(reply.docs[0].text)

    cancel_event.set()
    runtime_process.join()

    if uses == 'AsyncSlowNewDocsExecutor':
        expected_order = ['1', '3', '5', '7', '9', '2', '4', '6', '8', '10']
    else:
        expected_order = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    assert results == expected_order

    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
def test_base_polling(polling):
    """Check how many requests reach the workers of a two-shard runtime.

    One request is sent to ``/all`` and one to ``/any``. After each, the
    handle queue length is compared with the number expected for ``polling``.
    """
    args = set_pod_parser().parse_args(
        [
            '--polling',
            polling,
            '--shards',
            str(2),
        ]
    )
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    _add_worker(args, shard_id=0)
    _add_worker(args, shard_id=1)

    with grpc.insecure_channel(
        f'{args.host}:{args.port}',
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message(endpoint='all'),
            metadata=(('endpoint', '/all'),),
        )

    assert response
    # The parentheses are required: `x == 2 if c else 1` parses as
    # `(x == 2) if c else 1`, which would assert the always-true constant 1
    # whenever polling is not 'all'.
    assert _queue_length(handle_queue) == (2 if polling == 'all' else 1)

    with grpc.insecure_channel(
        f'{args.host}:{args.port}',
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message(endpoint='any'),
            metadata=(('endpoint', '/any'),),
        )

    assert response
    # Same precedence issue as above; the queue length is cumulative.
    assert _queue_length(handle_queue) == (4 if polling == 'all' else 2)

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_pod_activates_replicas():
    """Send six requests to a 3-replica, 2-shard Deployment and check the reply texts."""
    cli_args = ['--replicas', '3', '--shards', '2', '--disable-reduce']
    args = set_deployment_parser().parse_args(cli_args)
    args.uses = 'AppendNameExecutor'

    with Deployment(args) as pod:
        assert pod.num_pods == 7

        seen_texts = set()
        # Replicas are used in a round robin fashion (per the original note);
        # six requests are sent here.
        for _attempt in range(6):
            reply = GrpcConnectionPool.send_request_sync(
                _create_test_data_message(),
                f'{pod.head_args.host}:{pod.head_args.port}',
            )
            seen_texts.update(reply.response.docs.texts)

        assert 4 == len(seen_texts)
        assert {'0', '1', '2', 'client'}.issubset(seen_texts)

    Deployment(args).start().close()
def test_message_merging():
    """Send one request to a runtime with three shards (polling ALL) and check the merged reply."""
    args = set_pod_parser().parse_args([])
    args.polling = PollingType.ALL
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    assert handle_queue.empty()
    for shard_id, worker_ip in enumerate(('ip1', 'ip2', 'ip3')):
        _add_worker(args, worker_ip, shard_id=shard_id)
    assert handle_queue.empty()

    head_address = f'{args.host}:{args.port_in}'
    result = GrpcConnectionPool.send_request_sync(
        _create_test_data_message(), head_address
    )

    assert result
    # one queued request per shard, one doc in the reply
    assert _queue_length(handle_queue) == 3
    assert len(result.response.docs) == 1

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_error_in_worker_runtime(monkeypatch):
    """Check that a raising request handler yields an ERROR status and ``is-error`` metadata."""
    args = set_pod_parser().parse_args([])
    cancel_event = multiprocessing.Event()
    target = f'{args.host}:{args.port}'

    def _always_raise(*_args, **_kwargs):
        raise RuntimeError('intentional error')

    monkeypatch.setattr(DataRequestHandler, 'handle', _always_raise)

    def _run_worker(runtime_args, stop_event):
        with WorkerRuntime(runtime_args, cancel_event=stop_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=_run_worker,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    with grpc.insecure_channel(
        target,
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    assert response.header.status.code == jina_pb2.StatusProto.ERROR
    assert 'is-error' in dict(call.trailing_metadata())

    cancel_event.set()
    runtime_process.join()

    assert response
    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
def is_ready(ctrl_address: str, **kwargs) -> bool:
    """
    Check if status is ready, using a gRPC health check.

    :param ctrl_address: the address where the health check needs to be sent
    :param kwargs: extra keyword arguments (not used by this implementation)
    :return: True if the reported status is ``SERVING``; False for any other
        status or when the health check raises ``RpcError``.
    """
    from grpc_health.v1 import health_pb2

    try:
        response = GrpcConnectionPool.send_health_check_sync(
            ctrl_address, timeout=1.0
        )
    except RpcError:
        return False
    # Compare against the named enum value instead of the bare integer 1.
    return response.status == health_pb2.HealthCheckResponse.SERVING
def test_regular_data_case():
    """Send one data request to a runtime with a single configured connection (polling ANY)."""
    args = set_pod_parser().parse_args([])
    args.polling = PollingType.ANY
    args.connection_list = json.dumps({0: ['fake_ip:8080']})
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    runtime_address = f'{args.host}:{args.port}'
    grpc_options = GrpcConnectionPool.get_default_grpc_options()
    with grpc.insecure_channel(runtime_address, options=grpc_options) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    assert response
    assert 'is-error' in dict(call.trailing_metadata())
    assert len(response.docs) == 1
    assert not handle_queue.empty()

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_timeout_behaviour():
    """Send one data request to a runtime started with ``--timeout-send 100`` (polling ANY)."""
    args = set_pod_parser().parse_args(['--timeout-send', '100'])
    args.polling = PollingType.ANY
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)
    _add_worker(args)

    runtime_address = f'{args.host}:{args.port}'
    grpc_options = GrpcConnectionPool.get_default_grpc_options()
    with grpc.insecure_channel(runtime_address, options=grpc_options) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    assert response
    assert 'is-error' in dict(call.trailing_metadata())
    assert len(response.docs) == 1
    assert not handle_queue.empty()

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_pod_activates_shards():
    """Send one request to a 3-replica, 3-shard Deployment (polling ALL) and check the reply texts."""
    cli_args = ['--replicas', '3', '--shards', '3']
    args = set_deployment_parser().parse_args(cli_args)
    args.uses = 'AppendShardExecutor'
    args.polling = PollingType.ALL

    with Deployment(args) as pod:
        # 3 replicas for each of the 3 shards, plus 1 further pod
        assert pod.num_pods == 3 * 3 + 1

        seen_texts = set()
        # a single request is sent to the head
        reply = GrpcConnectionPool.send_request_sync(
            _create_test_data_message(),
            f'{pod.head_args.host}:{pod.head_args.port_in}',
        )
        seen_texts.update(reply.response.docs.texts)

        assert 4 == len(reply.response.docs.texts)
        assert 4 == len(seen_texts)
        assert {'0', '1', '2', 'client'}.issubset(seen_texts)

    Deployment(args).start().close()
async def test_pods_trivial_topology(port_generator):
    """Run gateway -> head -> worker pods and post requests through the gateway."""
    worker_port = port_generator()
    head_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    # one worker pod, one head pod and one gateway pod
    worker_pod = _create_worker_pod(worker_port)
    head_pod = _create_head_pod(head_port)
    gateway_pod = _create_gateway_pod(graph_description, pod_addresses, port)

    with gateway_pod, head_pod, worker_pod:
        # this would be done by the Pod: it adds the worker to the head
        for started_pod in (head_pod, worker_pod):
            started_pod.wait_start_success()
        activate_msg = ControlRequest(command='ACTIVATE')
        activate_msg.add_related_entity('worker', '127.0.0.1', worker_port)
        head_address = f'127.0.0.1:{head_port}'
        assert GrpcConnectionPool.send_request_sync(activate_msg, head_address)

        # send requests to the gateway
        gateway_pod.wait_start_success()
        client = Client(host='localhost', port=port, asyncio=True)
        responses = client.post(
            '/', inputs=async_inputs, request_size=1, return_responses=True
        )
        response_list = [response async for response in responses]

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1
def test_decompress(monkeypatch):
    """Check that ``DataRequest._decompress`` runs only once ``response.docs`` is accessed."""
    call_counts = multiprocessing.Manager().Queue()

    def _counting_decompress(self):
        # Record the call, then parse the raw buffer into the protobuf body.
        call_counts.put_nowait('called')
        from jina.proto import jina_pb2

        self._pb_body = jina_pb2.DataRequestProto()
        self._pb_body.ParseFromString(self.buffer)
        self.buffer = None

    monkeypatch.setattr(DataRequest, '_decompress', _counting_decompress)

    args = set_pod_parser().parse_args([])
    args.polling = PollingType.ANY
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)
    _add_worker(args)

    runtime_address = f'{args.host}:{args.port}'
    grpc_options = GrpcConnectionPool.get_default_grpc_options()
    with grpc.insecure_channel(runtime_address, options=grpc_options) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    assert response
    assert 'is-error' in dict(call.trailing_metadata())
    # The order below is significant: the counter is read before and after
    # the docs of the response are accessed.
    assert _queue_length_copy(call_counts) == 0
    assert len(response.docs) == 1
    assert _queue_length_copy(call_counts) == 1
    assert not handle_queue.empty()

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_message_merging(disable_reduce):
    """Send two identical requests to a three-shard runtime (polling ALL).

    :param disable_reduce: when truthy the runtime is started with
        ``--disable-reduce`` and two docs are expected in the reply instead of one
    """
    cli_args = ['--disable-reduce'] if disable_reduce else []
    args = set_pod_parser().parse_args(cli_args)
    args.polling = PollingType.ALL
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    assert handle_queue.empty()
    _add_worker(args, 'ip1', shard_id=0)
    _add_worker(args, 'ip2', shard_id=1)
    _add_worker(args, 'ip3', shard_id=2)
    assert handle_queue.empty()

    data_request = _create_test_data_message()
    result = GrpcConnectionPool.send_requests_sync(
        [data_request, data_request], f'{args.host}:{args.port}'
    )

    assert result
    assert _queue_length(handle_queue) == 3
    # The parentheses are required: `x == 2 if c else 1` parses as
    # `(x == 2) if c else 1`, which would assert the always-true constant 1
    # whenever reduce is enabled.
    assert len(result.response.docs) == (2 if disable_reduce else 1)

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_uses_before_uses_after():
    """Send one request through a runtime with uses_before, three shards and uses_after."""
    args = set_pod_parser().parse_args([])
    args.polling = PollingType.ALL
    args.uses_before_address = 'fake_address'
    args.uses_after_address = 'fake_address'
    args.connection_list = json.dumps(
        {0: ['ip1:8080'], 1: ['ip2:8080'], 2: ['ip3:8080']}
    )
    cancel_event, handle_queue, runtime_thread = _create_runtime(args)

    assert handle_queue.empty()

    runtime_address = f'{args.host}:{args.port}'
    result = GrpcConnectionPool.send_request_sync(
        _create_test_data_message(), runtime_address
    )

    assert result
    # uses_before + 3 workers + uses_after
    assert _queue_length(handle_queue) == 5
    assert len(result.response.docs) == 1

    _destroy_runtime(args, cancel_event, runtime_thread)
def test_worker_runtime():
    """Start a WorkerRuntime in a subprocess, send it one request, then shut it down."""
    args = set_pod_parser().parse_args([])
    cancel_event = multiprocessing.Event()
    target = f'{args.host}:{args.port}'

    def _run_worker(runtime_args, stop_event):
        with WorkerRuntime(runtime_args, cancel_event=stop_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=_run_worker,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    with grpc.insecure_channel(
        target,
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    cancel_event.set()
    runtime_process.join()

    assert response
    # once the runtime process has been joined it must no longer report ready
    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')