def real_jail_session(token):
    """
    Create a Jail service for testing over TCP sockets, using the token for auth.
    """
    auth_interceptor = Auth().get_auth_interceptor(allow_jailed=True)

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=[auth_interceptor])
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        servicer = Jail()
        jail_pb2_grpc.add_JailServicer_to_server(servicer, server)
        server.start()

        call_creds = grpc.metadata_call_credentials(CookieMetadataPlugin(token))
        comp_creds = grpc.composite_channel_credentials(grpc.local_channel_credentials(), call_creds)

        try:
            with grpc.secure_channel(f"localhost:{port}", comp_creds) as channel:
                yield jail_pb2_grpc.JailStub(channel)
        finally:
            server.stop(None).wait()

def interceptor_dummy_api(
    rpc,
    interceptors,
    service_name="testing.Test",
    method_name="TestRpc",
    request_type=empty_pb2.Empty,
    response_type=empty_pb2.Empty,
    creds=None,
):
    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=interceptors)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # manually add the handler
        rpc_method_handlers = {
            method_name: grpc.unary_unary_rpc_method_handler(
                rpc,
                request_deserializer=request_type.FromString,
                response_serializer=response_type.SerializeToString,
            )
        }
        generic_handler = grpc.method_handlers_generic_handler(service_name, rpc_method_handlers)
        server.add_generic_rpc_handlers((generic_handler,))
        server.start()

        try:
            with grpc.secure_channel(f"localhost:{port}", creds or grpc.local_channel_credentials()) as channel:
                call_rpc = channel.unary_unary(
                    f"/{service_name}/{method_name}",
                    request_serializer=request_type.SerializeToString,
                    response_deserializer=response_type.FromString,
                )
                yield call_rpc
        finally:
            server.stop(None).wait()

async def start_test_server(secure=False):
    server = aio.server(options=(('grpc.so_reuseport', 0),))
    servicer = _TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)

    # Programmatically add extra methods that are not provided by the proto
    # file but are used during the tests.
    rpc_method_handlers = {
        'UnaryCallWithSleep':
            grpc.unary_unary_rpc_method_handler(
                servicer.UnaryCallWithSleep,
                request_deserializer=messages_pb2.SimpleRequest.FromString,
                response_serializer=messages_pb2.SimpleResponse.SerializeToString)
    }
    extra_handler = grpc.method_handlers_generic_handler('grpc.testing.TestService',
                                                         rpc_method_handlers)
    server.add_generic_rpc_handlers((extra_handler,))

    if secure:
        server_credentials = grpc.local_server_credentials(
            grpc.LocalConnectionType.LOCAL_TCP)
        port = server.add_secure_port('[::]:0', server_credentials)
    else:
        port = server.add_insecure_port('[::]:0')

    await server.start()
    # NOTE(lidizheng) Also return the server to prevent it from being deallocated.
    return 'localhost:%d' % port, server

def setUp(self):
    super().setUp()

    # Set up MLMD connection.
    pipeline_root = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self.id())
    metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
    connection_config = metadata.sqlite_metadata_connection_config(metadata_path)
    connection_config.sqlite.SetInParent()
    self._mlmd_connection = metadata.Metadata(connection_config=connection_config)

    with self._mlmd_connection as m:
        self._execution = execution_publish_utils.register_execution(
            metadata_handler=m,
            execution_type=metadata_store_pb2.ExecutionType(name='test_execution_type'),
            contexts=[],
            input_artifacts=[])

    # Set up gRPC stub.
    port = portpicker.pick_unused_port()
    self.sidecar = execution_watcher.ExecutionWatcher(
        port,
        mlmd_connection=self._mlmd_connection,
        execution=self._execution,
        creds=grpc.local_server_credentials())
    self.sidecar.start()
    self.stub = execution_watcher.generate_service_stub(
        self.sidecar.address, grpc.local_channel_credentials())

def main(unused_argv):
    aead.register()
    daead.register()
    hybrid.register()
    mac.register()
    prf.register()
    signature.register()
    streaming_aead.register()
    jwt.register_jwt_mac()
    fake_kms.register_client()

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    testing_api_pb2_grpc.add_MetadataServicer_to_server(
        services.MetadataServicer(), server)
    testing_api_pb2_grpc.add_KeysetServicer_to_server(
        services.KeysetServicer(), server)
    testing_api_pb2_grpc.add_AeadServicer_to_server(services.AeadServicer(), server)
    testing_api_pb2_grpc.add_DeterministicAeadServicer_to_server(
        services.DeterministicAeadServicer(), server)
    testing_api_pb2_grpc.add_MacServicer_to_server(services.MacServicer(), server)
    testing_api_pb2_grpc.add_PrfSetServicer_to_server(
        services.PrfSetServicer(), server)
    testing_api_pb2_grpc.add_HybridServicer_to_server(
        services.HybridServicer(), server)
    testing_api_pb2_grpc.add_SignatureServicer_to_server(
        services.SignatureServicer(), server)
    testing_api_pb2_grpc.add_StreamingAeadServicer_to_server(
        services.StreamingAeadServicer(), server)
    testing_api_pb2_grpc.add_JwtServicer_to_server(jwt_service.JwtServicer(), server)

    server.add_secure_port('[::]:%d' % FLAGS.port, grpc.local_server_credentials())
    server.start()
    server.wait_for_termination()

def auth_api_session():
    """
    Create an Auth API for testing.

    This needs to use the real server since it plays around with headers.
    """
    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor)
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        auth_pb2_grpc.add_AuthServicer_to_server(Auth(), server)
        server.start()

        try:
            with grpc.secure_channel(f"localhost:{port}", grpc.local_channel_credentials()) as channel:

                class _MetadataKeeperInterceptor(grpc.UnaryUnaryClientInterceptor):
                    def __init__(self):
                        self.latest_headers = {}

                    def intercept_unary_unary(self, continuation, client_call_details, request):
                        call = continuation(client_call_details, request)
                        self.latest_headers = dict(call.initial_metadata())
                        return call

                metadata_interceptor = _MetadataKeeperInterceptor()
                channel = grpc.intercept_channel(channel, metadata_interceptor)
                yield auth_pb2_grpc.AuthStub(channel), metadata_interceptor
        finally:
            server.stop(None).wait()

def error_sanitizing_interceptor_dummy_api(TestRpc):
    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=[ErrorSanitizationInterceptor()])
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

        # manually add the handler
        rpc_method_handlers = {
            "TestRpc": grpc.unary_unary_rpc_method_handler(
                TestRpc,
                request_deserializer=empty_pb2.Empty.FromString,
                response_serializer=empty_pb2.Empty.SerializeToString,
            )
        }
        generic_handler = grpc.method_handlers_generic_handler("testing.Test", rpc_method_handlers)
        server.add_generic_rpc_handlers((generic_handler,))
        server.start()

        try:
            with grpc.secure_channel(f"localhost:{port}", grpc.local_channel_credentials()) as channel:
                call_rpc = channel.unary_unary(
                    "/testing.Test/TestRpc",
                    request_serializer=empty_pb2.Empty.SerializeToString,
                    response_deserializer=empty_pb2.Empty.FromString,
                )
                yield call_rpc
        finally:
            server.stop(None).wait()

def test_unary_unary_secure(self):
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        response = grpc.experimental.unary_unary(
            _REQUEST,
            target,
            _UNARY_UNARY,
            channel_credentials=grpc.local_channel_credentials())
        self.assertEqual(_REQUEST, response)

def main(unused_argv):
    aead.register()
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    testing_api_pb2_grpc.add_KeysetServicer_to_server(KeysetServicer(), server)
    testing_api_pb2_grpc.add_AeadServicer_to_server(AeadServicer(), server)
    server.add_secure_port('[::]:%d' % FLAGS.port, grpc.local_server_credentials())
    server.start()
    server.wait_for_termination()

def test_unary_stream(self):
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        for response in grpc.experimental.unary_stream(
                _REQUEST,
                target,
                _UNARY_STREAM,
                channel_credentials=grpc.local_channel_credentials()):
            self.assertEqual(_REQUEST, response)

def _start_server():
    """Starts the Catch gRPC server."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    servicer = catch_environment.CatchEnvironmentService()
    dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(servicer, server)
    port = server.add_secure_port('localhost:0', grpc.local_server_credentials())
    server.start()
    return server, port

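# Hedged usage sketch (not part of the original source): it assumes the
# _start_server() helper above and shows how a caller might pair the local
# server credentials with matching local channel credentials before shutting
# the server down. The function name _example_connect_to_catch_server is
# hypothetical.
def _example_connect_to_catch_server():
    server, port = _start_server()
    try:
        channel = grpc.secure_channel(
            f'localhost:{port}', grpc.local_channel_credentials())
        # Wait until the channel is usable, then release it.
        grpc.channel_ready_future(channel).result(timeout=5)
        channel.close()
    finally:
        server.stop(grace=None)
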
def test_create_secure_channel_and_connect(self):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    port = server.add_secure_port('[::]:0', grpc.local_server_credentials())
    server.start()

    self.assertIsNotNone(
        dm_env_rpc_connection.create_secure_channel_and_connect(
            f'[::]:{port}', grpc.local_channel_credentials()))

    server.stop(grace=None)

def main(_):
    logging.info('Loading server...')
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    environment_pb2_grpc.add_EnvironmentServiceServicer_to_server(
        EnvironmentServicer(), server)
    server_creds = grpc.local_server_credentials()
    server.add_secure_port('[::]:{}'.format(FLAGS.port), server_creds)
    server.start()
    logging.info('Running server on port %s...', FLAGS.port)
    server.wait_for_termination()

def mock_main_server(*args, **kwargs):
    server = grpc.server(futures.ThreadPoolExecutor(1))
    port = server.add_secure_port("localhost:8088", grpc.local_server_credentials())
    servicer = MockMainServer(*args, **kwargs)
    media_pb2_grpc.add_MediaServicer_to_server(servicer, server)
    server.start()
    try:
        yield port
    finally:
        server.stop(None).wait()

def test_stream_stream(self):

    def request_iter():
        for _ in range(_CLIENT_REQUEST_COUNT):
            yield _REQUEST

    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        for response in grpc.experimental.stream_stream(
                request_iter(),
                target,
                _STREAM_STREAM,
                channel_credentials=grpc.local_channel_credentials()):
            self.assertEqual(_REQUEST, response)

def test_channels_cached(self):
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        test_name = inspect.stack()[0][3]
        args = (_REQUEST, target, _UNARY_UNARY)
        kwargs = {"channel_credentials": grpc.local_channel_credentials()}

        def _invoke(seed: str):
            run_kwargs = dict(kwargs)
            run_kwargs["options"] = ((test_name + seed, ""),)
            grpc.experimental.unary_unary(*args, **run_kwargs)

        self.assert_cached(_invoke)

def testExecutionWatcher_LocalWithEmptyRequest(self):
    port = portpicker.pick_unused_port()
    sidecar = execution_watcher.ExecutionWatcher(
        port, creds=grpc.local_server_credentials())
    sidecar.start()

    creds = grpc.local_channel_credentials()
    channel = grpc.secure_channel(sidecar.local_address, creds)
    stub = execution_watcher_pb2_grpc.ExecutionWatcherServiceStub(channel)
    req = execution_watcher_pb2.UpdateExecutionInfoRequest()
    res = stub.UpdateExecutionInfo(req)
    sidecar.stop()

    self.assertEqual(execution_watcher_pb2.UpdateExecutionInfoResponse(), res)

def test_channels_evicted(self):
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        response = grpc.experimental.unary_unary(
            _REQUEST,
            target,
            _UNARY_UNARY,
            channel_credentials=grpc.local_channel_credentials())
        self.assert_eventually(
            lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() == 0,
            message=lambda:
            f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain")

def __init__(self):
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    servicer = catch_environment.CatchEnvironmentService()
    dm_env_rpc_pb2_grpc.add_EnvironmentServicer_to_server(servicer, self._server)
    port = self._server.add_secure_port('[::]:0', grpc.local_server_credentials())
    self._server.start()

    self._channel = grpc.secure_channel(f'[::]:{port}',
                                        grpc.local_channel_credentials())
    grpc.channel_ready_future(self._channel).result()

    self.connection = dm_env_rpc_connection.Connection(self._channel)

def test_uds(self):
    server_addr = 'unix:/tmp/grpc_fullstack_test'
    channel_creds = grpc.local_channel_credentials(grpc.LocalConnectionType.UDS)
    server_creds = grpc.local_server_credentials(grpc.LocalConnectionType.UDS)

    server = self._create_server()
    server.add_secure_port(server_addr, server_creds)
    server.start()
    with grpc.secure_channel(server_addr, channel_creds) as channel:
        self.assertEqual(
            b'abc',
            channel.unary_unary('/test/method')(b'abc', wait_for_ready=True))
    server.stop(None)

def test_local_tcp(self):
    server_addr = 'localhost:{}'
    channel_creds = grpc.local_channel_credentials(grpc.LocalConnectionType.LOCAL_TCP)
    server_creds = grpc.local_server_credentials(grpc.LocalConnectionType.LOCAL_TCP)

    server = self._create_server()
    port = server.add_secure_port(server_addr.format(0), server_creds)
    server.start()
    with grpc.secure_channel(server_addr.format(port), channel_creds) as channel:
        self.assertEqual(
            b'abc',
            channel.unary_unary('/test/method')(b'abc', wait_for_ready=True))
    server.stop(None)

def run(self):
    """Context manager to run the gRPC server and yield a client for it."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    grpc_util_test_pb2_grpc.add_TestServiceServicer_to_server(self, server)
    port = server.add_secure_port("localhost:0", grpc.local_server_credentials())

    def launch_server():
        server.start()
        server.wait_for_termination()

    thread = threading.Thread(target=launch_server, name="TestGrpcServer")
    thread.daemon = True
    thread.start()

    with grpc.secure_channel(
            "localhost:%d" % port, grpc.local_channel_credentials()) as channel:
        yield grpc_util_test_pb2_grpc.TestServiceStub(channel)

    server.stop(grace=None)
    thread.join()

def setup_server(port=None):
    num_requests = min(2 * os.cpu_count(), MAX_CONCURRENT_REQUESTS)
    server = grpc.server(
        concurrent.futures.ThreadPoolExecutor(max_workers=2 * num_requests),
        maximum_concurrent_rpcs=num_requests,
    )
    sampler_grpc.add_ExactTestSamplerServicer_to_server(ExactTestSampler(), server)

    listening_address = "localhost:0"
    if port is not None:
        listening_address = "localhost:%i" % port
    actual_port = server.add_secure_port(listening_address,
                                         grpc.local_server_credentials())
    if actual_port == 0:
        raise Exception("Failed to open port %s for gRPC server." % port)
    return server, actual_port

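# Hedged usage sketch (not part of the original source): assuming setup_server()
# above, a caller still has to start the returned server and connect with
# grpc.local_channel_credentials(), since add_secure_port() alone does not
# begin serving. The function name _example_run_sampler_server is hypothetical.
def _example_run_sampler_server():
    server, actual_port = setup_server()
    server.start()
    try:
        with grpc.secure_channel(
                "localhost:%d" % actual_port,
                grpc.local_channel_credentials()) as channel:
            # Block until the connection is established.
            grpc.channel_ready_future(channel).result(timeout=5)
    finally:
        server.stop(grace=None)
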
def test_total_channels_enforced(self):
    with _server(grpc.local_server_credentials()) as port:
        target = f'localhost:{port}'
        for i in range(_STRESS_EPOCHS):
            # Ensure we get a new channel each time.
            options = (("foo", str(i)),)
            # Send messages at full blast.
            grpc.experimental.unary_unary(
                _REQUEST,
                target,
                _UNARY_UNARY,
                options=options,
                channel_credentials=grpc.local_channel_credentials())
            self.assert_eventually(
                lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1,
                message=lambda:
                f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain")

async def start_test_server(port=0, secure=False, server_credentials=None):
    server = aio.server(options=(('grpc.so_reuseport', 0),))
    servicer = _TestServiceServicer()
    test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
    server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),))

    if secure:
        if server_credentials is None:
            server_credentials = grpc.local_server_credentials(
                grpc.LocalConnectionType.LOCAL_TCP)
        port = server.add_secure_port('[::]:%d' % port, server_credentials)
    else:
        port = server.add_insecure_port('[::]:%d' % port)

    await server.start()
    # NOTE(lidizheng) Also return the server to prevent it from being deallocated.
    return 'localhost:%d' % port, server

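# Hedged usage sketch (not part of the original source): assuming the async
# start_test_server() above, a secure aio client would connect with matching
# local channel credentials. The coroutine name _example_call_test_server is
# hypothetical.
async def _example_call_test_server():
    address, server = await start_test_server(secure=True)
    channel = aio.secure_channel(address, grpc.local_channel_credentials())
    # Wait for the connection, then shut everything down cleanly.
    await channel.channel_ready()
    await channel.close()
    await server.stop(None)
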
def media_session(bearer_token):
    """
    Create a fresh Media API for testing, using the bearer token for media auth.
    """
    media_auth_interceptor = get_media_auth_interceptor(bearer_token)

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=[media_auth_interceptor])
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        servicer = Media()
        media_pb2_grpc.add_MediaServicer_to_server(servicer, server)
        server.start()

        call_creds = grpc.access_token_call_credentials(bearer_token)
        comp_creds = grpc.composite_channel_credentials(grpc.local_channel_credentials(), call_creds)

        try:
            with grpc.secure_channel(f"localhost:{port}", comp_creds) as channel:
                yield media_pb2_grpc.MediaStub(channel)
        finally:
            server.stop(None).wait()

def _create_server(self):
    self.server = None
    if self.address is not None:
        return

    from a3m.server.runner import create_server

    server_credentials = grpc.local_server_credentials(
        grpc.LocalConnectionType.LOCAL_TCP)
    self.server = create_server(
        self.BIND_LOCAL_ADDRESS,
        server_credentials,
        settings.CONCURRENT_PACKAGES,
        settings.BATCH_SIZE,
        settings.WORKER_THREADS,
        settings.RPC_THREADS,
        settings.DEBUG,
    )

    # Compute the address, since the port was dynamically assigned.
    self.address = f"localhost:{self.server.grpc_port}"

def main():
    init_django()
    suppress_warnings()

    from a3m.server.runner import create_server

    logger.info(
        f"Starting a3m... (version={__version__} pid={os.getpid()} "
        f"uid={os.getuid()} python={platform.python_version()} "
        f"listen={settings.RPC_BIND_ADDRESS})"
    )

    # A3M-TODO: make this configurable, e.g. local TCP, local UDS, TLS certs...
    # (see https://grpc.github.io/grpc/python/grpc.html#create-server-credentials for more)
    server_credentials = grpc.local_server_credentials(
        grpc.LocalConnectionType.LOCAL_TCP
    )

    server = create_server(
        settings.RPC_BIND_ADDRESS,
        server_credentials,
        settings.CONCURRENT_PACKAGES,
        settings.BATCH_SIZE,
        settings.WORKER_THREADS,
        settings.RPC_THREADS,
        settings.DEBUG,
    )
    server.start()

    def signal_handler(signo, frame):
        logger.info("Received termination signal (%s)", signal.Signals(signo).name)
        server.stop()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    server.wait_for_termination()
    logger.info("a3m shutdown complete.")

def real_api_session(db, token):
    """
    Create an API for testing over TCP sockets, using the token for auth.
    """
    auth_interceptor = Auth(db).get_auth_interceptor(allow_jailed=False)

    with futures.ThreadPoolExecutor(1) as executor:
        server = grpc.server(executor, interceptors=[auth_interceptor])
        port = server.add_secure_port("localhost:0", grpc.local_server_credentials())
        servicer = API(db)
        api_pb2_grpc.add_APIServicer_to_server(servicer, server)
        server.start()

        call_creds = grpc.access_token_call_credentials(token)
        comp_creds = grpc.composite_channel_credentials(grpc.local_channel_credentials(), call_creds)

        try:
            with grpc.secure_channel(f"localhost:{port}", comp_creds) as channel:
                yield api_pb2_grpc.APIStub(channel)
        finally:
            server.stop(None).wait()

def launch(self) -> Optional[data_types.ExecutionInfo]:
    """Executes the component, including driver, executor and publisher.

    Returns:
        The metadata of this execution that is registered in MLMD. It can be
        None if the driver decides not to run the execution.

    Raises:
        Exception: If the executor fails.
    """
    logging.info('Running launcher for %s', self._pipeline_node)
    if self._system_node_handler:
        # If this is a system node, run it and return directly.
        return self._system_node_handler.run(self._mlmd_connection,
                                             self._pipeline_node,
                                             self._pipeline_info,
                                             self._pipeline_runtime_spec)

    # Run as a normal node.
    execution_preparation_result = self._prepare_execution()
    (execution_info, contexts, is_execution_needed) = (
        execution_preparation_result.execution_info,
        execution_preparation_result.contexts,
        execution_preparation_result.is_execution_needed)
    if is_execution_needed:
        try:
            executor_watcher = None
            if self._executor_operator:
                # Create an execution watcher and pass an in-memory copy of the
                # Execution object to it. The launcher calls the executor
                # operator in process, so there is no race condition between
                # the execution watcher and the launcher when writing to MLMD.
                executor_watcher = execution_watcher.ExecutionWatcher(
                    port=portpicker.pick_unused_port(),
                    mlmd_connection=self._mlmd_connection,
                    execution=execution_preparation_result.execution_metadata,
                    creds=grpc.local_server_credentials())
                self._executor_operator.with_execution_watcher(
                    executor_watcher.address)
                executor_watcher.start()
            executor_output = self._run_executor(execution_info)
        except Exception as e:  # pylint: disable=broad-except
            execution_output = (
                e.executor_output if isinstance(e, _ExecutionFailedError) else None)
            self._publish_failed_execution(execution_info.execution_id, contexts,
                                           execution_output)
            logging.error('Execution %d failed.', execution_info.execution_id)
            raise
        finally:
            self._clean_up_stateless_execution_info(execution_info)
            if executor_watcher:
                executor_watcher.stop()

        logging.info('Execution %d succeeded.', execution_info.execution_id)
        self._clean_up_stateful_execution_info(execution_info)

        # TODO(b/182316162): Unify publisher handling so that post-execution
        # artifact logic is handled more cleanly.
        # Note that currently both the ExecutionInfo and ExecutorOutput are
        # consulted in `execution_publish_utils.publish_succeeded_execution()`.
        outputs_utils.tag_executor_output_with_version(executor_output)
        outputs_utils.tag_output_artifacts_with_version(execution_info.output_dict)
        logging.info('Publishing output artifacts %s for execution %s',
                     execution_info.output_dict, execution_info.execution_id)
        self._publish_successful_execution(execution_info.execution_id, contexts,
                                           execution_info.output_dict,
                                           executor_output)
    return execution_info