def test_jina_info_grpc_based_runtimes(runtime, port_generator):
    """Check that gRPC-based runtimes answer the ``_status`` info call.

    The response must report the installed Jina version and contain every
    environment variable listed in ``__jina_env__``.

    :param runtime: which runtime to start: ``'head'``, ``'gateway'``, or
        anything else for a worker
    :param port_generator: fixture returning a free port
    """
    port = port_generator()
    connection_list_dict = {}
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{port}"]}}'

    if runtime == 'head':
        p = _create_head(port, connection_list_dict)
    elif runtime == 'gateway':
        p = _create_gateway(port, graph_description, pod_addresses, 'grpc')
    else:
        p = _create_worker(port)

    try:
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

        # Exceptions and failed assertions are allowed to propagate so pytest
        # reports the real cause instead of a bare ``assert False``.
        # The channel is closed as soon as the single RPC has completed.
        with grpc.insecure_channel(f'localhost:{port}') as channel:
            stub = jina_pb2_grpc.JinaInfoRPCStub(channel)
            res = stub._status(
                jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
            )
        assert res.jina['jina'] == __version__
        for env_var in __jina_env__:
            assert env_var in res.envs
    finally:
        p.terminate()
        p.join()
def test_jina_info_gateway_http(protocol, port_generator):
    """Check that an HTTP-based gateway exposes Jina info on ``/status``.

    The JSON response must contain the ``jina`` and ``envs`` sections, report
    the installed Jina version, and list every variable in ``__jina_env__``.

    :param protocol: gateway protocol under test
    :param port_generator: fixture returning a free port
    """
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{port}"]}}'

    p = _create_gateway(port, graph_description, pod_addresses, protocol)

    try:
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

        # Exceptions and failed assertions are allowed to propagate so pytest
        # reports the real cause instead of a bare ``assert False``.
        resp = requests.get(f'http://localhost:{port}/status').json()
        assert 'jina' in resp
        assert 'envs' in resp
        assert resp['jina']['jina'] == __version__
        for env_var in __jina_env__:
            assert env_var in resp['envs']
    finally:
        p.terminate()
        p.join()
async def test_decorator_monitoring(port_generator):
    """Check that ``@monitor`` on Executor methods exports Prometheus metrics.

    A worker runtime with monitoring enabled serves one request. The metric
    named ``metrics_name`` must then report a count of 1 on the monitoring port.

    :param port_generator: fixture returning a free port (used for monitoring)
    """
    from jina import monitor

    class DummyExecutor(Executor):
        @requests
        def foo(self, docs, **kwargs):
            self._proces(docs)
            self.process_2(docs)

        @monitor(name='metrics_name', documentation='metrics description')
        def _proces(self, docs):
            ...

        @monitor()
        def process_2(self, docs):
            ...

    port = port_generator()
    args = set_pod_parser().parse_args(
        ['--monitoring', '--port-monitoring', str(port), '--uses', 'DummyExecutor']
    )
    address = f'{args.host}:{args.port}'

    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=start_runtime,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_process.start()

    try:
        # One readiness check is sufficient; repeating the identical wait adds nothing.
        assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=address,
            ready_or_shutdown_event=Event(),
        )

        await GrpcConnectionPool.send_request_async(
            _create_test_data_message(), address, timeout=1.0
        )

        resp = req.get(f'http://localhost:{port}/')
        assert f'jina_metrics_name_count{{runtime_name="None"}} 1.0' in str(
            resp.content
        )
    finally:
        # Stop the runtime even when an assertion fails, so the process does
        # not linger holding the test ports.
        cancel_event.set()
        runtime_process.join()

    assert not AsyncNewLoopRuntime.is_ready(address)
async def _activate_runtimes(head_port, worker_ports):
    """Wait for each worker runtime, then register it with the head.

    :param head_port: port of the head runtime (passed to ``_activate_worker``)
    :param worker_ports: worker ports; each worker's position in this list is
        passed as its ``shard_id``
    """
    for shard_id, port in enumerate(worker_ports):
        worker_address = f'127.0.0.1:{port}'
        # The readiness result is not checked; activation is attempted anyway.
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ready_or_shutdown_event=threading.Event(),
            ctrl_address=worker_address,
        )
        await _activate_worker(head_port, port, shard_id=shard_id)
def test_custom_num_retries(port_generator, retries, capfd):
    """Check that the gateway honours a user-defined number of gRPC retries.

    If a negative number is given, the default policy applies: every replica
    is hit at least once. Gateway and workers are created manually; the workers
    are then terminated to provoke an error.

    :param port_generator: fixture returning a free port
    :param retries: number of retries configured on the gateway
    :param capfd: pytest output-capture fixture, forwarded to ``_test_custom_retry``
    """
    num_replicas = 3
    worker_ports = [port_generator() for _ in range(num_replicas)]
    worker0_port, worker1_port, worker2_port = worker_ports
    gateway_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker0_port}", "0.0.0.0:{worker1_port}", "0.0.0.0:{worker2_port}"]}}'

    worker_processes = []
    for worker_port in worker_ports:
        worker_processes.append(_create_worker(worker_port))
        time.sleep(0.1)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{worker_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    gateway_process = _create_gateway(
        gateway_port, graph_description, pod_addresses, 'grpc', retries=retries
    )
    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{gateway_port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    # Failures propagate unchanged (no ``except: assert False``) so pytest
    # shows the real cause; ``finally`` still cleans up the runtimes.
    try:
        # ----------- 1. ping Flow once to trigger endpoint discovery -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        p = multiprocessing.Process(target=_send_request, args=(gateway_port, 'grpc'))
        p.start()
        p.join()
        assert p.exitcode == 0

        # kill all workers
        for worker_process in worker_processes:
            worker_process.terminate()
            worker_process.join()

        # ----------- 2. test that call will be retried the appropriate number of times -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        p = multiprocessing.Process(
            target=_test_custom_retry,
            args=(gateway_port, worker_ports, 'grpc', retries, capfd),
        )
        p.start()
        p.join()
        assert p.exitcode == 0
    finally:
        # clean up runtimes
        gateway_process.terminate()
        gateway_process.join()
        for worker_process in worker_processes:
            worker_process.terminate()
            worker_process.join()
async def test_runtimes_graphql(port_generator):
    """Check that a GraphQL gateway reports useful errors and survives them.

    Gateway and worker are created manually; the worker is then terminated to
    provoke an error. The error check runs twice to show the gateway is still
    alive after the first failure.

    :param port_generator: fixture returning a free port
    """
    worker_port = port_generator()
    gateway_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

    worker_process = _create_worker(worker_port)
    gateway_process = _create_gqlgateway(gateway_port, graph_description, pod_addresses)

    time.sleep(1.0)

    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{worker_port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{gateway_port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    worker_process.terminate()  # kill worker
    worker_process.join()

    try:
        # Round 1: test that useful errors are given.
        # Round 2: do the same again, expecting the same outcome, to show that
        # the gateway remains alive.
        for _ in range(2):
            # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
            p = multiprocessing.Process(
                target=_test_gql_error, args=(gateway_port, worker_port)
            )
            p.start()
            p.join()
            # if exitcode != 0 then test in other process did not pass and this should fail
            assert p.exitcode == 0
    finally:
        # clean up runtimes
        gateway_process.terminate()
        worker_process.terminate()
        gateway_process.join()
        worker_process.join()
async def test_replica_retry_all_fail(port_generator):
    """Check that a call fails with an informative error when all replicas are down.

    Gateway and three worker replicas are created manually. After one request
    triggers endpoint discovery, every worker is terminated and the next call
    must fail with a useful error.

    :param port_generator: fixture returning a free port
    """
    worker_ports = [port_generator() for _ in range(3)]
    worker0_port, worker1_port, worker2_port = worker_ports
    gateway_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker0_port}", "0.0.0.0:{worker1_port}", "0.0.0.0:{worker2_port}"]}}'

    worker_processes = []
    for worker_port in worker_ports:
        worker_processes.append(_create_worker(worker_port))
        time.sleep(0.1)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{worker_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    gateway_process = _create_gateway(
        gateway_port, graph_description, pod_addresses, 'grpc'
    )
    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{gateway_port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    # Failures propagate unchanged (no ``except: assert False``) so pytest
    # shows the real cause; ``finally`` still cleans up the runtimes.
    try:
        # ----------- 1. ping Flow once to trigger endpoint discovery -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        p = multiprocessing.Process(target=_send_request, args=(gateway_port, 'grpc'))
        p.start()
        p.join()
        assert p.exitcode == 0

        # kill all workers
        for worker_process in worker_processes:
            worker_process.terminate()
            worker_process.join()

        # ----------- 2. test that call fails with informative error message -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        p = multiprocessing.Process(
            target=_test_error, args=(gateway_port, worker_ports, 'grpc')
        )
        p.start()
        p.join()
        assert p.exitcode == 0
    finally:
        # clean up runtimes
        gateway_process.terminate()
        gateway_process.join()
        for worker_process in worker_processes:
            worker_process.terminate()
            worker_process.join()
async def test_runtimes_replicas(port_generator, protocol):
    """Check that a useful error is reported when one of three replicas is down.

    Gateway and three worker replicas are created manually; the first worker
    (``worker0``) is terminated before any request is sent.

    :param port_generator: fixture returning a free port
    :param protocol: gateway protocol under test
    """
    worker_ports = [port_generator() for _ in range(3)]
    worker0_port, worker1_port, worker2_port = worker_ports
    gateway_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker0_port}", "0.0.0.0:{worker1_port}", "0.0.0.0:{worker2_port}"]}}'

    worker_processes = []
    for worker_port in worker_ports:
        worker_processes.append(_create_worker(worker_port))
        time.sleep(0.1)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{worker_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    gateway_process = _create_gateway(
        gateway_port, graph_description, pod_addresses, protocol
    )
    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{gateway_port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    worker_processes[0].terminate()  # kill the first worker (worker0)
    worker_processes[0].join()

    # Failures propagate unchanged (no ``except: assert False``) so pytest
    # shows the real cause; ``finally`` still cleans up the runtimes.
    try:
        # ----------- 1. test that useful errors are given -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        p = multiprocessing.Process(
            target=_test_error, args=(gateway_port, worker0_port, protocol)
        )
        p.start()
        p.join()
        # if exitcode != 0 then test in other process did not pass and this should fail
        assert p.exitcode == 0
        # no retry in the case with replicas, because round robin retry mechanism will pick different replica now
    finally:
        # clean up runtimes
        gateway_process.terminate()
        gateway_process.join()
        for worker_process in worker_processes:
            worker_process.terminate()
            worker_process.join()
def _create_runtime(args):
    """Start a ``HeadRuntime`` in a child process with a mocked request sender.

    Inside the child, the connection pool's ``_send_requests`` is replaced by a
    mock. Each call to the mock puts ``'mock_called'`` on a queue, sleeps 0.1 s,
    and answers with the first request plus ``is-error: true`` trailing metadata.

    :param args: parsed pod arguments; inside the child, ``args.name`` is set to
        ``'testHead'`` when it is missing or empty
    :return: tuple ``(cancel_event, handle_queue, runtime_process)``
    """
    calls_queue = multiprocessing.Queue()
    stop_event = multiprocessing.Event()

    def start_runtime(args, handle_queue, cancel_event):
        def _send_requests_mock(request: List[Request], connection, endpoint) -> asyncio.Task:
            async def mock_task_wrapper(new_requests, *args, **kwargs):
                handle_queue.put('mock_called')
                await asyncio.sleep(0.1)
                error_metadata = grpc.aio.Metadata.from_tuple((('is-error', 'true'),))
                return new_requests[0], error_metadata

            return asyncio.create_task(mock_task_wrapper(request, connection))

        # Covers both a missing attribute and a falsy (e.g. empty) name.
        if not getattr(args, 'name', None):
            args.name = 'testHead'

        with HeadRuntime(args, cancel_event) as runtime:
            runtime.connection_pool._send_requests = _send_requests_mock
            runtime.run_forever()

    runtime_process = Process(
        target=start_runtime,
        args=(args, calls_queue, stop_event),
        daemon=True,
    )
    runtime_process.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port_in}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )
    return stop_event, calls_queue, runtime_process
async def test_worker_runtime_reflection():
    """Check that a worker runtime advertises its data-request gRPC services."""
    args = set_pod_parser().parse_args([])
    address = f'{args.host}:{args.port}'
    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    worker = Process(target=start_runtime, args=(args, cancel_event), daemon=True)
    worker.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=3.0,
        ctrl_address=address,
        ready_or_shutdown_event=Event(),
    )

    async with grpc.aio.insecure_channel(address) as channel:
        service_names = await GrpcConnectionPool.get_available_services(channel)

    expected_services = ['jina.JinaDataRequestRPC', 'jina.JinaSingleDataRequestRPC']
    assert all(name in service_names for name in expected_services)

    cancel_event.set()
    worker.join()
    assert not AsyncNewLoopRuntime.is_ready(address)
def test_dry_run_of_flow(port_generator, protocol):
    """Check that ``Client.dry_run`` reflects whether the Flow's worker is reachable.

    :param port_generator: fixture returning a free port
    :param protocol: gateway/client protocol under test
    """
    worker_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

    # one worker runtime and one gateway runtime, each in its own process
    worker_process = multiprocessing.Process(
        target=_create_worker_runtime, args=(worker_port,)
    )
    worker_process.start()

    gateway_process = multiprocessing.Process(
        target=_create_gateway_runtime,
        args=(graph_description, pod_addresses, port, protocol),
    )
    gateway_process.start()

    for ready_port in (worker_port, port):
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{ready_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    client = Client(host='localhost', port=port, asyncio=True, protocol=protocol)
    dry_run_alive = client.dry_run()

    # remove the worker, then probe again
    worker_process.terminate()
    worker_process.join()
    dry_run_worker_removed = client.dry_run()

    gateway_process.terminate()
    gateway_process.join()

    assert dry_run_alive
    assert not dry_run_worker_removed
    assert gateway_process.exitcode == 0
    assert worker_process.exitcode == 0
async def test_runtimes_gateway_worker_direct_connection(port_generator):
    """Check that a gateway can talk to a worker directly, with no head in between.

    :param port_generator: fixture returning a free port
    """
    worker_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

    # the single shard
    worker_process = multiprocessing.Process(
        target=_create_worker_runtime, args=(worker_port, 'pod0')
    )
    worker_process.start()
    await asyncio.sleep(0.1)

    # the gateway
    gateway_process = multiprocessing.Process(
        target=_create_gateway_runtime,
        args=(graph_description, pod_addresses, port),
    )
    gateway_process.start()
    await asyncio.sleep(1.0)

    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    client = Client(host='localhost', port=port, asyncio=True)
    responses = client.post(
        '/', inputs=async_inputs, request_size=1, return_responses=True
    )
    response_list = [response async for response in responses]

    # stop both runtimes first, then wait for both
    processes = (gateway_process, worker_process)
    for process in processes:
        process.terminate()
    for process in processes:
        process.join()

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1
    assert gateway_process.exitcode == 0
    assert worker_process.exitcode == 0
async def test_blocking_sync_exec():
    """Check a worker whose synchronous Executor blocks, under many concurrent requests.

    All ``REQUEST_COUNT`` requests are sent at once. Every response must carry
    the Executor's text, and the total wall time must stay below twice the time
    needed to process the requests back to back.
    """
    SLEEP_TIME = 0.01
    REQUEST_COUNT = 100

    class BlockingExecutor(Executor):
        @requests
        def foo(self, docs: DocumentArray, **kwargs):
            time.sleep(SLEEP_TIME)
            for doc in docs:
                doc.text = 'BlockingExecutor'
            return docs

    args = set_pod_parser().parse_args(['--uses', 'BlockingExecutor'])
    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    worker = Process(target=start_runtime, args=(args, cancel_event), daemon=True)
    worker.start()

    address = f'{args.host}:{args.port}'
    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=address,
        ready_or_shutdown_event=Event(),
    )

    started_at = time.time()
    # schedule every request up front, then await them together
    send_tasks = [
        asyncio.create_task(
            GrpcConnectionPool.send_request_async(
                _create_test_data_message(),
                target=address,
                timeout=3.0,
            )
        )
        for _ in range(REQUEST_COUNT)
    ]
    results = await asyncio.gather(*send_tasks)
    elapsed = time.time() - started_at

    assert all(result.docs.texts == ['BlockingExecutor'] for result in results)
    assert elapsed < (REQUEST_COUNT * SLEEP_TIME) * 2.0

    cancel_event.set()
    worker.join()
def _wait_for_ready_or_shutdown(self, timeout: Optional[float]):
    """
    Block until the runtime reports readiness or is known to have failed.

    Delegates to ``AsyncNewLoopRuntime.wait_for_ready_or_shutdown``, using this
    object's ready-or-shutdown event, runtime control address and control
    timeout.

    :param timeout: The time to wait before readiness or failure is determined
    .. # noqa: DAR201
    """
    wait_kwargs = dict(
        timeout=timeout,
        ready_or_shutdown_event=self.ready_or_shutdown.event,
        ctrl_address=self.runtime_ctrl_address,
        timeout_ctrl=self._timeout_ctrl,
    )
    return AsyncNewLoopRuntime.wait_for_ready_or_shutdown(**wait_kwargs)
async def test_worker_runtime_slow_async_exec(uses):
    """Check the completion order of ten concurrent requests to an async Executor.

    :param uses: name of the Executor class to load in the worker
    """
    args = set_pod_parser().parse_args(['--uses', uses])
    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    worker = Process(target=start_runtime, args=(args, cancel_event), daemon=True)
    worker.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    target = f'{args.host}:{args.port}'
    results = []
    async with grpc.aio.insecure_channel(
        target,
        options=GrpcConnectionPool.get_default_grpc_options(),
    ) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)

        async def _call_once():
            return await stub.process_single_data(_create_test_data_message())

        pending = [asyncio.create_task(_call_once()) for _ in range(10)]
        # record texts in the order the responses arrive
        for next_done in asyncio.as_completed(pending):
            reply = await next_done
            results.append(reply.docs[0].text)

    cancel_event.set()
    worker.join()

    if uses == 'AsyncSlowNewDocsExecutor':
        expected = ['1', '3', '5', '7', '9', '2', '4', '6', '8', '10']
    else:
        expected = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    assert results == expected
    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
async def test_worker_runtime_slow_init_exec():
    """Check that a worker only opens its gRPC port once a slow Executor is initialized.

    Per the comments below, ``SlowInitExecutor`` sleeps 5 seconds in its
    constructor (assumption taken from this test; verify against the Executor's
    definition).
    """
    args = set_pod_parser().parse_args(['--uses', 'SlowInitExecutor'])
    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    runtime_process = Process(
        target=start_runtime,
        args=(args, cancel_event),
        daemon=True,
    )
    runtime_started = time.time()
    runtime_process.start()

    try:
        # wait a bit so the worker runtime has a chance to finish some things, but not the Executor init (5 secs)
        time.sleep(1.0)

        # try to connect a TCP socket to the gRPC server
        # this should only succeed after the Executor is ready, which should be after 5 seconds
        connect_deadline = time.monotonic() + 60.0
        connected = False
        while not connected:
            # Use a fresh socket per attempt: after a failed connect() the
            # socket's state is unspecified, so retrying on it is not portable.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                try:
                    s.connect((args.host, args.port))
                    connected = True
                except ConnectionRefusedError:
                    # bound the wait so a broken runtime fails the test instead of hanging it
                    assert (
                        time.monotonic() < connect_deadline
                    ), 'worker runtime never opened its gRPC port'
                    time.sleep(0.2)

        # Executor sleeps 5 seconds, so at least 5 seconds need to have elapsed here
        assert time.time() - runtime_started > 5.0

        assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=3.0,
            ctrl_address=f'{args.host}:{args.port}',
            ready_or_shutdown_event=Event(),
        )

        result = await GrpcConnectionPool.send_request_async(
            _create_test_data_message(), f'{args.host}:{args.port}', timeout=1.0
        )
        assert len(result.docs) == 1
    finally:
        # always stop the runtime, even when an assertion above fails
        cancel_event.set()
        runtime_process.join()

    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
def test_error_in_worker_runtime(monkeypatch):
    """Check that an exception in request handling becomes an ERROR response.

    ``DataRequestHandler.handle`` is patched to raise. The worker must answer
    with an ERROR status and set ``is-error`` in the trailing metadata.

    :param monkeypatch: pytest fixture used to patch the request handler
    """
    args = set_pod_parser().parse_args([])
    cancel_event = multiprocessing.Event()

    def _raise_intentional_error(*args, **kwargs):
        raise RuntimeError('intentional error')

    monkeypatch.setattr(DataRequestHandler, 'handle', _raise_intentional_error)

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    worker = Process(target=start_runtime, args=(args, cancel_event), daemon=True)
    worker.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    target = f'{args.host}:{args.port}'
    grpc_options = GrpcConnectionPool.get_default_grpc_options()
    with grpc.insecure_channel(target, options=grpc_options) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, call = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    assert response.header.status.code == jina_pb2.StatusProto.ERROR
    assert 'is-error' in dict(call.trailing_metadata())

    cancel_event.set()
    worker.join()

    assert response
    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
async def test_head_runtime_reflection():
    """Check that a head runtime advertises its data-request gRPC services."""
    args = set_pod_parser().parse_args([])
    cancel_event, _, runtime_process = _create_runtime(args)
    address = f'{args.host}:{args.port}'

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=3.0,
        ctrl_address=address,
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    async with grpc.aio.insecure_channel(address) as channel:
        service_names = await GrpcConnectionPool.get_available_services(channel)

    expected_services = ['jina.JinaDataRequestRPC', 'jina.JinaSingleDataRequestRPC']
    assert all(name in service_names for name in expected_services)

    _destroy_runtime(args, cancel_event, runtime_process)
def test_worker_runtime():
    """Check that a worker runtime answers a single data request over gRPC."""
    args = set_pod_parser().parse_args([])
    cancel_event = multiprocessing.Event()

    def start_runtime(args, cancel_event):
        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime.run_forever()

    worker = Process(target=start_runtime, args=(args, cancel_event), daemon=True)
    worker.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    target = f'{args.host}:{args.port}'
    grpc_options = GrpcConnectionPool.get_default_grpc_options()
    with grpc.insecure_channel(target, options=grpc_options) as channel:
        stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
        response, _ = stub.process_single_data.with_call(
            _create_test_data_message()
        )

    cancel_event.set()
    worker.join()

    assert response
    assert not AsyncNewLoopRuntime.is_ready(f'{args.host}:{args.port}')
async def test_runtimes_trivial_topology(port_generator):
    """Run requests through a gateway -> head -> worker chain with one of each.

    :param port_generator: fixture returning a free port
    """
    worker_port = port_generator()
    head_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    # one worker, one head and one gateway runtime, each in its own process
    worker_process = multiprocessing.Process(
        target=_create_worker_runtime, args=(worker_port,)
    )
    worker_process.start()

    head_process = multiprocessing.Process(
        target=_create_head_runtime, args=(head_port,)
    )
    head_process.start()

    gateway_process = multiprocessing.Process(
        target=_create_gateway_runtime,
        args=(graph_description, pod_addresses, port),
    )
    gateway_process.start()

    await asyncio.sleep(1.0)

    for ready_port in (head_port, worker_port, port):
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{ready_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    # this would be done by the Pod, its adding the worker to the head
    activate_msg = ControlRequest(command='ACTIVATE')
    activate_msg.add_related_entity('worker', '127.0.0.1', worker_port)
    GrpcConnectionPool.send_request_sync(activate_msg, f'127.0.0.1:{head_port}')

    # send requests to the gateway
    client = Client(host='localhost', port=port, asyncio=True)
    responses = client.post(
        '/', inputs=async_inputs, request_size=1, return_responses=True
    )
    response_list = [response async for response in responses]

    # stop all runtimes first, then wait for each of them
    processes = (gateway_process, head_process, worker_process)
    for process in processes:
        process.terminate()
    for process in processes:
        process.join()

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1

    for process in processes:
        assert process.exitcode == 0
async def test_runtimes_headful_topology(port_generator, protocol, terminate_head):
    """Check that errors are reported, and the gateway survives, when a pod behind a head dies.

    Gateway, head and worker are created manually. Then either the head or the
    worker behind it is terminated to provoke an error, and the error check runs
    twice.

    :param port_generator: fixture returning a free port
    :param protocol: gateway protocol under test
    :param terminate_head: terminate the head if true, otherwise the worker
    """
    worker_port = port_generator()
    gateway_port = port_generator()
    head_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']}
    head_process = _create_head(head_port, connection_list_dict, 'ANY')
    worker_process = _create_worker(worker_port)
    gateway_process = _create_gateway(
        gateway_port, graph_description, pod_addresses, protocol
    )

    time.sleep(1.0)

    for ready_port in (head_port, worker_port, gateway_port):
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{ready_port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

    # terminate pod, either head or worker behind the head
    if terminate_head:
        head_process.terminate()
        head_process.join()
        error_port = head_port
    else:
        worker_process.terminate()  # kill worker
        worker_process.join()
        error_port = worker_port
    if protocol == 'websocket':
        # due to error msg length constraints ws will always report the head address
        error_port = head_port

    try:
        # Round 1: test that useful errors are given.
        # Round 2: do the same again, expecting the same outcome, to show that
        # the gateway remains alive.
        for _ in range(2):
            # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
            p = multiprocessing.Process(
                target=_test_error, args=(gateway_port, error_port, protocol)
            )
            p.start()
            p.join()
            # if exitcode != 0 then test in other process did not pass and this should fail
            assert p.exitcode == 0
    finally:
        # clean up runtimes
        gateway_process.terminate()
        worker_process.terminate()
        head_process.terminate()
        gateway_process.join()
        worker_process.join()
        head_process.join()
async def test_worker_runtime_graceful_shutdown():
    """Check that a terminated worker finishes pending requests before closing its handler.

    The request handler is patched to sleep ``slow_executor_block_time`` per
    request, and its ``close`` sets an event. The runtime process is terminated
    while requests are in flight. Every request must still complete, and the
    handler must be closed only after that.
    """
    args = set_pod_parser().parse_args([])

    cancel_event = multiprocessing.Event()
    handler_closed_event = multiprocessing.Event()
    slow_executor_block_time = 1.0
    pending_requests = 5

    def start_runtime(args, cancel_event, handler_closed_event):
        def _slow_handle(*args, **kwargs):
            return time.sleep(slow_executor_block_time)

        def _record_close(*args, **kwargs):
            return handler_closed_event.set()

        with WorkerRuntime(args, cancel_event=cancel_event) as runtime:
            runtime._data_request_handler.handle = _slow_handle
            runtime._data_request_handler.close = _record_close
            runtime.run_forever()

    runtime_thread = Process(
        target=start_runtime,
        args=(args, cancel_event, handler_closed_event),
    )
    runtime_thread.start()

    assert AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'{args.host}:{args.port}',
        ready_or_shutdown_event=Event(),
    )

    request_start_time = time.time()

    async def task_wrapper(address, messages_received):
        request = _create_test_data_message(len(messages_received))
        _, data_stub, _, channel = GrpcConnectionPool.create_async_channel_stub(
            address
        )
        await data_stub.process_data(request)
        await channel.close()
        messages_received.append(request)

    messages_received = []
    tasks = [
        asyncio.create_task(
            task_wrapper(f'{args.host}:{args.port}', messages_received)
        )
        for _ in range(pending_requests)
    ]
    sent_requests = len(tasks)

    await asyncio.sleep(1.0)

    # terminate while requests are still pending; the handler must not be closed yet
    runtime_thread.terminate()
    assert not handler_closed_event.is_set()
    runtime_thread.join()

    for future in asyncio.as_completed(tasks):
        await future

    assert pending_requests == sent_requests
    assert sent_requests == len(messages_received)
    assert (
        time.time() - request_start_time
        >= slow_executor_block_time * pending_requests
    )
    assert handler_closed_event.is_set()
    assert not WorkerRuntime.is_ready(f'{args.host}:{args.port}')
async def test_runtimes_flow_topology(
    complete_graph_dict, uses_before, uses_after, port_generator
):
    """Run requests through a full Flow topology built from head and worker runtimes.

    For every non-gateway pod in ``complete_graph_dict``, this starts an
    optional uses_before worker, an optional uses_after worker, a head and a
    worker, and activates the worker on the head. A gateway then routes 20
    requests through the graph.

    :param complete_graph_dict: graph description (pod name -> successor pods)
    :param uses_before: whether each pod gets a uses_before worker
    :param uses_after: whether each pod gets a uses_after worker
    :param port_generator: fixture returning a free port
    """
    pods = [pod_name for pod_name in complete_graph_dict if 'gateway' not in pod_name]
    runtime_processes = []
    # pod name -> list of head addresses; serialised with json.dumps below
    pod_address_map = {}
    for pod in pods:
        if uses_before:
            uses_before_port, uses_before_process = await _create_worker(
                pod, port_generator, type='uses_before'
            )
            AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
                timeout=5.0,
                ready_or_shutdown_event=threading.Event(),
                ctrl_address=f'127.0.0.1:{uses_before_port}',
            )
            runtime_processes.append(uses_before_process)
        if uses_after:
            uses_after_port, uses_after_process = await _create_worker(
                pod, port_generator, type='uses_after'
            )
            AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
                timeout=5.0,
                ready_or_shutdown_event=threading.Event(),
                ctrl_address=f'127.0.0.1:{uses_after_port}',
            )
            runtime_processes.append(uses_after_process)

        # create head
        head_port = port_generator()
        pod_address_map[pod] = [f'0.0.0.0:{head_port}']
        head_process = multiprocessing.Process(
            target=_create_head_runtime,
            args=(
                head_port,
                f'{pod}/head',
                'ANY',
                f'127.0.0.1:{uses_before_port}' if uses_before else None,
                f'127.0.0.1:{uses_after_port}' if uses_after else None,
            ),
        )
        runtime_processes.append(head_process)
        head_process.start()

        # create worker
        worker_port, worker_process = await _create_worker(pod, port_generator)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ready_or_shutdown_event=threading.Event(),
            ctrl_address=f'127.0.0.1:{worker_port}',
        )
        runtime_processes.append(worker_process)
        await asyncio.sleep(0.1)

        await _activate_worker(head_port, worker_port)

    # json.dumps always yields valid JSON (also for an empty pod list), unlike
    # concatenating fragments and stripping a trailing comma.
    pod_addresses = json.dumps(pod_address_map)

    port = port_generator()
    # create a single gateway runtime
    gateway_process = multiprocessing.Process(
        target=_create_gateway_runtime,
        args=(json.dumps(complete_graph_dict), pod_addresses, port),
    )
    gateway_process.start()

    await asyncio.sleep(0.1)

    AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
        timeout=5.0,
        ctrl_address=f'0.0.0.0:{port}',
        ready_or_shutdown_event=multiprocessing.Event(),
    )

    # send requests to the gateway
    c = Client(host='localhost', port=port, asyncio=True)
    responses = c.post('/', inputs=async_inputs, request_size=1, return_responses=True)
    response_list = []
    async for response in responses:
        response_list.append(response)

    # clean up runtimes
    gateway_process.terminate()
    for process in runtime_processes:
        process.terminate()

    gateway_process.join()
    for process in runtime_processes:
        process.join()

    assert len(response_list) == 20
    assert len(response_list[0].docs) == 1

    assert gateway_process.exitcode == 0
    for process in runtime_processes:
        assert process.exitcode == 0
async def test_runtimes_with_replicas_advance_faster(port_generator):
    """Check that a 'fast' request overtakes a 'slow' one across replicas.

    A head runtime with 10 ``FastSlowExecutor`` worker replicas is put behind a
    single gateway. A ``'slow'`` document is sent before a ``'fast'`` one
    (``request_size=1``), and the response for ``'fast'`` is expected to be
    received first. Runtimes are always terminated and joined in ``finally`` so
    a failing request cannot leak them.
    """
    head_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    # create a single head runtime
    head_process = multiprocessing.Process(
        target=_create_head_runtime, args=(head_port, 'head')
    )
    head_process.start()

    replica_processes = []
    gateway_process = None
    try:
        # create the replicas
        worker_ports = []
        for i in range(10):
            worker_port = port_generator()
            worker_process = multiprocessing.Process(
                target=_create_worker_runtime,
                args=(worker_port, f'pod0/{i}', 'FastSlowExecutor'),
            )
            worker_process.start()
            # only register started processes, so teardown never touches an
            # unstarted one
            replica_processes.append(worker_process)
            await asyncio.sleep(0.1)
            worker_ports.append(worker_port)

        await _activate_runtimes(head_port, worker_ports)

        # create a single gateway runtime
        gateway_process = multiprocessing.Process(
            target=_create_gateway_runtime,
            args=(graph_description, pod_addresses, port),
        )
        gateway_process.start()
        await asyncio.sleep(1.0)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

        c = Client(host='localhost', port=port, asyncio=True)
        input_docs = [Document(text='slow'), Document(text='fast')]
        responses = c.post(
            '/', inputs=input_docs, request_size=1, return_responses=True
        )
        response_list = [response async for response in responses]
    finally:
        # clean up runtimes: signal everything first, then wait for everything
        if gateway_process is not None:
            gateway_process.terminate()
        head_process.terminate()
        for replica_process in replica_processes:
            replica_process.terminate()
        if gateway_process is not None:
            gateway_process.join()
        head_process.join()
        for replica_process in replica_processes:
            replica_process.join()

    assert len(response_list) == 2
    for response in response_list:
        assert len(response.docs) == 1

    assert response_list[0].docs[0].text == 'fast'
    assert response_list[1].docs[0].text == 'slow'

    assert gateway_process.exitcode == 0
    assert head_process.exitcode == 0
    for replica_process in replica_processes:
        assert replica_process.exitcode == 0
def test_custom_num_retries_headful(port_generator, retries, capfd):
    """Check that the gateway retries a failing pod the configured number of times.

    A head, one worker and a grpc gateway are started manually with
    ``retries``. One request is sent so the flow is discovered; then the
    worker is killed, and ``_test_custom_retry`` verifies the retry behaviour
    in a separate process.

    Failures inside the ``try`` propagate unchanged. There is no blanket
    ``except Exception: assert False``, which would also have swallowed the
    ``AssertionError`` messages of the checks below. All runtimes are torn
    down in ``finally``.
    """
    # create gateway and workers manually, then terminate worker process to provoke an error
    worker_port = port_generator()
    gateway_port = port_generator()
    head_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'
    connection_list_dict = {'0': [f'127.0.0.1:{worker_port}']}

    # runtimes in creation order; torn down in reverse order in `finally`
    started_processes = []
    try:
        head_process = _create_head(
            head_port, connection_list_dict, 'ANY', retries=retries
        )
        started_processes.append(head_process)
        worker_process = _create_worker(worker_port)
        started_processes.append(worker_process)
        gateway_process = _create_gateway(
            gateway_port, graph_description, pod_addresses, 'grpc', retries=retries
        )
        started_processes.append(gateway_process)

        time.sleep(1.0)
        for ctrl_port in (head_port, worker_port, gateway_port):
            AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
                timeout=5.0,
                ctrl_address=f'0.0.0.0:{ctrl_port}',
                ready_or_shutdown_event=multiprocessing.Event(),
            )

        # ----------- 1. ping Flow once to trigger endpoint discovery -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        ping_process = multiprocessing.Process(
            target=_send_request, args=(gateway_port, 'grpc')
        )
        ping_process.start()
        ping_process.join()
        assert ping_process.exitcode == 0

        # kill worker
        worker_process.terminate()
        worker_process.join()

        # ----------- 2. test that call will be retried the appropriate number of times -----------
        # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
        retry_process = multiprocessing.Process(
            target=_test_custom_retry,
            args=(gateway_port, worker_port, 'grpc', retries, capfd),
        )
        retry_process.start()
        retry_process.join()
        assert retry_process.exitcode == 0
    finally:
        # clean up runtimes (gateway, worker, head)
        for process in reversed(started_processes):
            process.terminate()
        for process in reversed(started_processes):
            process.join()
async def test_runtimes_with_executor(port_generator):
    """Run a pod with uses_before, 10 shards and uses_after, all using ``NameChangeExecutor``.

    The head polls with ``'ALL'``, so every shard receives each request. The
    test inspects the texts of the docs in the first response to check which
    runtimes handled it. Every runtime is registered exactly once in
    ``runtime_processes`` and is always terminated and joined in ``finally``.
    """
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    runtime_processes = []
    try:
        uses_before_port, uses_before_process = await _create_worker(
            'pod0', port_generator, type='uses_before', executor='NameChangeExecutor'
        )
        runtime_processes.append(uses_before_process)

        uses_after_port, uses_after_process = await _create_worker(
            'pod0', port_generator, type='uses_after', executor='NameChangeExecutor'
        )
        runtime_processes.append(uses_after_process)

        # create head
        head_port = port_generator()
        pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'
        head_process = multiprocessing.Process(
            target=_create_head_runtime,
            args=(
                head_port,
                'pod0/head',
                'ALL',
                f'127.0.0.1:{uses_before_port}',
                f'127.0.0.1:{uses_after_port}',
            ),
        )
        head_process.start()
        # register the head once, so it is terminated/joined once
        runtime_processes.append(head_process)

        # create some shards
        worker_ports = []
        for i in range(10):
            worker_port, worker_process = await _create_worker(
                'pod0', port_generator, type=f'shards/{i}', executor='NameChangeExecutor'
            )
            runtime_processes.append(worker_process)
            await asyncio.sleep(0.1)
            worker_ports.append(worker_port)

        await _activate_runtimes(head_port, worker_ports)

        # create a single gateway runtime
        port = port_generator()
        gateway_process = multiprocessing.Process(
            target=_create_gateway_runtime,
            args=(graph_description, pod_addresses, port),
        )
        gateway_process.start()
        runtime_processes.append(gateway_process)

        await asyncio.sleep(1.0)

        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

        c = Client(host='localhost', port=port, asyncio=True)
        responses = c.post(
            '/', inputs=async_inputs, request_size=1, return_responses=True
        )
        response_list = [response.docs async for response in responses]
    finally:
        # clean up runtimes: signal everything first, then wait for everything
        for process in runtime_processes:
            process.terminate()
        for process in runtime_processes:
            process.join()

    assert len(response_list) == 20
    # each of the 10 shards returns 3 docs (the client doc, the uses_before doc
    # and the doc it adds itself); uses_after merges these 30 and adds 1 more
    assert len(response_list[0]) == (1 + 1 + 1) * 10 + 1

    doc_texts = [doc.text for doc in response_list[0]]
    assert doc_texts.count('client0-Request') == 10
    assert doc_texts.count('pod0/uses_before') == 10
    assert doc_texts.count('pod0/uses_after') == 1
    for i in range(10):
        assert doc_texts.count(f'pod0/shards/{i}') == 1
async def test_runtimes_headless_topology(
    port_generator, protocol, fail_before_endpoint_discovery
):
    """Check gateway error handling when its only (headless) worker dies.

    A worker and a gateway (``protocol``) are started manually and the worker
    is killed either before the first request (endpoint discovery fails) or
    after a first request that performed endpoint discovery. ``_test_error``
    is then run in a separate process and must exit with code 0.

    Failures propagate unchanged. There is no blanket
    ``except Exception: assert False``, which would also have replaced the
    ``AssertionError`` messages below. Both runtimes are always torn down in
    ``finally``.
    """
    # create gateway and workers manually, then terminate worker process to provoke an error
    worker_port = port_generator()
    gateway_port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

    # runtimes in creation order; torn down in reverse order in `finally`
    started_processes = []
    try:
        worker_process = _create_worker(worker_port)
        started_processes.append(worker_process)
        gateway_process = _create_gateway(
            gateway_port, graph_description, pod_addresses, protocol
        )
        started_processes.append(gateway_process)

        time.sleep(1.0)
        for ctrl_port in (worker_port, gateway_port):
            AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
                timeout=5.0,
                ctrl_address=f'0.0.0.0:{ctrl_port}',
                ready_or_shutdown_event=multiprocessing.Event(),
            )

        if fail_before_endpoint_discovery:
            # kill worker before having sent the first request, so before endpoint discov.
            worker_process.terminate()
            worker_process.join()

            # here worker is already dead before the first request, so endpoint discovery will fail
            # ----------- 1. test that useful errors are given when endpoint discovery fails -----------
            # we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
            discovery_error_process = multiprocessing.Process(
                target=_test_error, args=(gateway_port, worker_port, protocol)
            )
            discovery_error_process.start()
            discovery_error_process.join()
            # if exitcode != 0 then test in other process did not pass and this should fail
            assert discovery_error_process.exitcode == 0
        else:
            # just ping the Flow without having killed a worker before. This (also) performs endpoint discovery
            ping_process = multiprocessing.Process(
                target=_send_request, args=(gateway_port, protocol)
            )
            ping_process.start()
            ping_process.join()
            # the discovery ping must itself succeed, otherwise the rest of the
            # test is not exercising the "fail after discovery" path
            assert ping_process.exitcode == 0
            # only now do we kill the worker, after having performed successful endpoint discovery
            # so in this case, the actual request will fail, not the discovery, which is handled differently by Gateway
            worker_process.terminate()  # kill worker
            worker_process.join()
            assert not worker_process.is_alive()

        # ----------- 2. test that gateways remain alive -----------
        # just do the same again, expecting the same failure
        request_error_process = multiprocessing.Process(
            target=_test_error, args=(gateway_port, worker_port, protocol)
        )
        request_error_process.start()
        request_error_process.join()
        # if exitcode != 0 then test in other process did not pass and this should fail
        assert request_error_process.exitcode == 0
    finally:
        # clean up runtimes (gateway, worker)
        for process in reversed(started_processes):
            process.terminate()
        for process in reversed(started_processes):
            process.join()
async def test_runtimes_shards(polling, port_generator):
    """Run a head with 10 shard workers behind a gateway using ``polling``.

    Sends the ``async_inputs`` requests through the gateway and checks the
    number of responses, the docs in the first response (for ``'ANY'``
    polling) and that every runtime exits with code 0. Runtimes are always
    terminated and joined in ``finally`` so a failing request cannot leak them.
    """
    head_port = port_generator()
    port = port_generator()
    graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
    pod_addresses = f'{{"pod0": ["0.0.0.0:{head_port}"]}}'

    # create a single head runtime
    head_process = multiprocessing.Process(
        target=_create_head_runtime, args=(head_port, 'head', polling)
    )
    head_process.start()

    shard_processes = []
    gateway_process = None
    try:
        # create the shards
        worker_ports = []
        for i in range(10):
            worker_port = port_generator()
            worker_process = multiprocessing.Process(
                target=_create_worker_runtime, args=(worker_port, f'pod0/shard/{i}')
            )
            worker_process.start()
            # only register started processes, so teardown never touches an
            # unstarted one
            shard_processes.append(worker_process)
            await asyncio.sleep(0.1)
            worker_ports.append(worker_port)

        await _activate_runtimes(head_port, worker_ports)

        # create a single gateway runtime
        gateway_process = multiprocessing.Process(
            target=_create_gateway_runtime,
            args=(graph_description, pod_addresses, port),
        )
        gateway_process.start()
        await asyncio.sleep(1.0)
        AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
            timeout=5.0,
            ctrl_address=f'0.0.0.0:{port}',
            ready_or_shutdown_event=multiprocessing.Event(),
        )

        c = Client(host='localhost', port=port, asyncio=True)
        responses = c.post(
            '/', inputs=async_inputs, request_size=1, return_responses=True
        )
        response_list = [response async for response in responses]
    finally:
        # clean up runtimes: signal everything first, then wait for everything
        if gateway_process is not None:
            gateway_process.terminate()
        head_process.terminate()
        for shard_process in shard_processes:
            shard_process.terminate()
        if gateway_process is not None:
            gateway_process.join()
        head_process.join()
        for shard_process in shard_processes:
            shard_process.join()

    assert len(response_list) == 20
    # NOTE(review): a one-line `assert n == 1 if polling == 'ANY' else k` parses
    # as `assert (n == 1) if polling == 'ANY' else k`, i.e. it checks nothing
    # for non-'ANY' polling. Only the 'ANY' check is made explicit here.
    # TODO confirm the expected doc count for 'ALL' polling (it depends on how
    # the head merges the shard responses) and assert it as well.
    if polling == 'ANY':
        assert len(response_list[0].docs) == 1
    assert gateway_process.exitcode == 0
    assert head_process.exitcode == 0
    for shard_process in shard_processes:
        assert shard_process.exitcode == 0