def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a client executor talking to an in-process gRPC executor service.

  The server side hosts an `ExecutorService` around a `TracingExecutor` that
  wraps an executor created by `executor_stacks.local_executor_factory` with
  `num_clients=3`. The client side is a `RemoteExecutor` on an insecure
  localhost channel, wrapped in a `ReferenceResolvingExecutor`.

  Args:
    rpc_mode: The RPC mode forwarded to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the client-side
    `ReferenceResolvingExecutor`) and `tracer` (the server-side
    `TracingExecutor`).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Everything after `server.start()` is guarded. The server is stopped even
  # if client setup fails or one of the earlier cleanup steps raises.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = reference_resolving_executor.ReferenceResolvingExecutor(
          remote_exec)
      try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
      finally:
        try:
          executor.close()
        finally:
          tracer.close()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a default client executor backed by an in-process gRPC service.

  The server side hosts an `ExecutorService` around a `TracingExecutor` that
  wraps an `EagerExecutor`. The client side is a `RemoteExecutor` on an
  insecure localhost channel, wrapped in a `LambdaExecutor`. That executor is
  installed as the default executor for the duration of the context.

  Args:
    rpc_mode: The RPC mode forwarded to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the client-side `LambdaExecutor`) and
    `tracer` (the server-side `TracingExecutor`).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = eager_executor.EagerExecutor()
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Everything after `server.start()` is guarded. The server is stopped and
  # the default executor is reset even if an earlier step raises.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = lambda_executor.LambdaExecutor(remote_exec)
      set_default_executor.set_default_executor(executor)
      try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
      finally:
        try:
          # NOTE(review): `__del__` is invoked explicitly, presumably because
          # this version of RemoteExecutor exposes no public close(); confirm.
          remote_exec.__del__()
        finally:
          set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
def main(argv):
  """Serves an `EagerExecutor` over a TLS-secured gRPC endpoint until Ctrl-C.

  The private key and certificate chain are read as raw bytes from the files
  named by `FLAGS.private_key` and `FLAGS.certificate_chain`. The secure port
  is bound at `FLAGS.endpoint`, and `FLAGS.threads` bounds the handler pool.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv
  tf.compat.v1.enable_v2_behavior()

  # TODO(b/134543154): Replace this with the complete local executor stack.
  service = tff.framework.ExecutorService(tff.framework.EagerExecutor())
  handler_pool = concurrent.futures.ThreadPoolExecutor(
      max_workers=FLAGS.threads)
  server = grpc.server(handler_pool)

  def _slurp(path):
    # Read an entire credential file as bytes.
    with open(path, 'rb') as stream:
      return stream.read()

  # Left-to-right evaluation keeps the original read order: key, then chain.
  key_chain_pair = (_slurp(FLAGS.private_key), _slurp(FLAGS.certificate_chain))
  server_creds = grpc.ssl_server_credentials((key_chain_pair,))

  server.add_secure_port(FLAGS.endpoint, server_creds)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()

  try:
    # Block the main thread; the server does its work on other threads.
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    server.stop(None)
def server_context(ex_factory: executor_factory.ExecutorFactory,
                   num_threads: int,
                   port: int,
                   credentials: Optional[grpc.ServerCredentials] = None,
                   options: Optional[List[Tuple[Any, Any]]] = None):
  """Context manager yielding gRPC server hosting simulation component.

  A `KeyboardInterrupt` raised while the server is yielded is logged and
  swallowed. On exit, the server is stopped, its thread pool is shut down and
  `ex_factory.clean_up_executors()` is called.

  Args:
    ex_factory: The executor factory to be hosted by the server.
    num_threads: The number of network threads to use for handling gRPC calls.
    port: The port to listen on (for gRPC), must be a positive integer.
    credentials: The optional credentials to use for the secure connection if
      any, or `None` if the server should open an insecure port. If specified,
      must be a valid `ServerCredentials` object that can be accepted by the
      gRPC server's `add_secure_port()`.
    options: The optional `list` of server options, each in the `(key, value)`
      format accepted by the `grpc.server()` constructor.

  Yields:
    The constructed gRPC server.

  Raises:
    ValueError: If `num_threads` or `port` are invalid.
  """
  py_typecheck.check_type(ex_factory, executor_factory.ExecutorFactory)
  py_typecheck.check_type(num_threads, int)
  py_typecheck.check_type(port, int)
  if credentials is not None:
    py_typecheck.check_type(credentials, grpc.ServerCredentials)
  if num_threads < 1:
    raise ValueError('The number of threads must be a positive integer.')
  if port < 1:
    raise ValueError('The server port must be a positive integer.')
  # Pre-bind both handles. If setup fails before they are created, the
  # `finally` clause must not raise UnboundLocalError and mask the real error.
  thread_pool_executor = None
  server = None
  try:
    service = executor_service.ExecutorService(ex_factory)
    server_kwargs = {}
    if options is not None:
      server_kwargs['options'] = options
    thread_pool_executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads)
    server = grpc.server(thread_pool_executor, **server_kwargs)
    full_port_string = '[::]:{}'.format(port)
    if credentials is not None:
      server.add_secure_port(full_port_string, credentials)
    else:
      server.add_insecure_port(full_port_string)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()
    yield server
  except KeyboardInterrupt:
    logging.info('Server stopped by KeyboardInterrupt.')
  finally:
    logging.info('Shutting down server.')
    # Stop the server before tearing down the pool that it dispatches
    # handlers onto.
    if server is not None:
      server.stop(None)
    if thread_pool_executor is not None:
      thread_pool_executor.shutdown(wait=False)
    ex_factory.clean_up_executors()
def __init__(self, executor):
  """Hosts `executor` on a local gRPC server and opens a stub to it.

  Args:
    executor: The executor exposed through an `ExecutorService`.
  """
  port = portpicker.pick_unused_port()
  bind_address = '[::]:{}'.format(port)
  client_address = 'localhost:{}'.format(port)

  # Server side: a single-worker pool, insecure port, service registration.
  self._server = grpc.server(logging_pool.pool(max_workers=1))
  self._server.add_insecure_port(bind_address)
  self._service = executor_service.ExecutorService(executor)
  executor_pb2_grpc.add_ExecutorServicer_to_server(self._service, self._server)
  self._server.start()

  # Client side: an insecure channel and a stub bound to it.
  self._channel = grpc.insecure_channel(client_address)
  self._stub = executor_pb2_grpc.ExecutorStub(self._channel)
def setUp(self):
  """Starts an `ExecutorService` around an `EagerExecutor` and connects a stub."""
  super(ExecutorServiceTest, self).setUp()
  port = portpicker.pick_unused_port()
  bind_address = '[::]:{}'.format(port)
  client_address = 'localhost:{}'.format(port)

  # Server side: a single-worker pool, insecure port, service registration.
  self._server = grpc.server(logging_pool.pool(max_workers=1))
  self._server.add_insecure_port(bind_address)
  self._service = executor_service.ExecutorService(
      eager_executor.EagerExecutor())
  executor_pb2_grpc.add_ExecutorServicer_to_server(self._service, self._server)
  self._server.start()

  # Client side: an insecure channel and a stub bound to it.
  self._channel = grpc.insecure_channel(client_address)
  self._stub = executor_pb2_grpc.ExecutorStub(self._channel)
def run_server(executor, num_threads, port, credentials=None, options=None):
  """Runs a gRPC server hosting a simulation component in this process.

  The server runs indefinitely, but can be stopped by a keyboard interrupt.
  However the loop exits, the server is stopped and its handler thread pool
  is shut down.

  Args:
    executor: The executor to be hosted by the server.
    num_threads: The number of network threads to use for handling gRPC calls.
    port: The port to listen on (for gRPC), must be a positive integer.
    credentials: The optional credentials to use for the secure connection if
      any, or `None` if the server should open an insecure port. If specified,
      must be a valid `ServerCredentials` object that can be accepted by the
      gRPC server's `add_secure_port()`.
    options: The optional `list` of server options, each in the `(key, value)`
      format accepted by the `grpc.server()` constructor.

  Raises:
    ValueError: If `num_threads` or `port` are invalid.
  """
  py_typecheck.check_type(executor, framework.Executor)
  py_typecheck.check_type(num_threads, int)
  py_typecheck.check_type(port, int)
  if credentials is not None:
    py_typecheck.check_type(credentials, grpc.ServerCredentials)
  if num_threads < 1:
    raise ValueError('The number of threads must be a positive integer.')
  if port < 1:
    raise ValueError('The server port must be a positive integer.')
  service = framework.ExecutorService(executor)
  server_kwargs = {}
  if options is not None:
    server_kwargs['options'] = options
  # Keep a handle on the pool so that it can be shut down on exit.
  thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
  server = grpc.server(thread_pool, **server_kwargs)
  full_port_string = '[::]:{}'.format(port)
  if credentials is not None:
    server.add_secure_port(full_port_string, credentials)
  else:
    server.add_insecure_port(full_port_string)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  try:
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    pass  # The documented way to stop the server; cleanup happens below.
  finally:
    # Stop the server before tearing down the pool that it dispatches
    # handlers onto.
    server.stop(None)
    thread_pool.shutdown(wait=False)
def __init__(self,
             ex_factory: executor_factory.ExecutorFactory,
             num_clients: int = 0):
  """Hosts `ex_factory` on a local gRPC server and sets client cardinality.

  Args:
    ex_factory: The executor factory handed to the `ExecutorService`.
    num_clients: The `CLIENTS` cardinality sent through `SetCardinalities`
      once the stub is connected.
  """
  port = portpicker.pick_unused_port()
  bind_address = '[::]:{}'.format(port)
  client_address = 'localhost:{}'.format(port)

  # Server side. The pool is kept on `self`, as in the original.
  self._server_pool = logging_pool.pool(max_workers=1)
  self._server = grpc.server(self._server_pool)
  self._server.add_insecure_port(bind_address)
  self._service = executor_service.ExecutorService(ex_factory=ex_factory)
  executor_pb2_grpc.add_ExecutorServicer_to_server(self._service, self._server)
  self._server.start()

  # Client side.
  self._channel = grpc.insecure_channel(client_address)
  self._stub = executor_pb2_grpc.ExecutorStub(self._channel)

  # Tell the service how many clients to expect.
  cardinalities = {placement_literals.CLIENTS: num_clients}
  request = executor_pb2.SetCardinalitiesRequest(
      cardinalities=executor_serialization.serialize_cardinalities(
          cardinalities))
  self._stub.SetCardinalities(request)
def test_context():
  """Yields a default `RemoteExecutor` backed by an in-process gRPC service.

  The server side hosts an `ExecutorService` around an `EagerExecutor`. The
  client-side `RemoteExecutor` is installed as the default executor for the
  duration of the context.

  Yields:
    The client-side `RemoteExecutor`.
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = eager_executor.EagerExecutor()
  service = executor_service.ExecutorService(target_executor)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Without try/finally, an exception raised in the caller's `with` body
  # would skip every cleanup step below and leak the running server.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      executor = remote_executor.RemoteExecutor(channel)
      set_default_executor.set_default_executor(executor)
      try:
        yield executor
      finally:
        set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a client executor plus the server-side tracers it drives.

  The server side hosts an `ExecutorService` over a
  `ResourceManagingExecutorFactory`. Each executor that factory creates is a
  `TracingExecutor` wrapping one from
  `executor_stacks.local_executor_factory`, and it is recorded in `tracers`.
  The `CLIENTS` cardinality is set over the wire before the client-side
  `RemoteExecutor` / `ReferenceResolvingExecutor` pair is built.

  Args:
    rpc_mode: The RPC mode forwarded to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the client-side
    `ReferenceResolvingExecutor`) and `tracers` (the list of server-side
    `TracingExecutor`s created so far).
  """
  num_clients = 3
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_factory = executor_stacks.local_executor_factory(
      num_clients=num_clients)
  tracers = []

  def _tracer_fn(cardinalities):
    # Wrap each executor created by the factory and remember it for the caller.
    tracer = executor_test_utils.TracingExecutor(
        target_factory.create_executor(cardinalities))
    tracers.append(tracer)
    return tracer

  service = executor_service.ExecutorService(
      executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # Everything after `server.start()` is guarded. This includes the
  # SetCardinalities RPC, so the server is always stopped.
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      stub = executor_pb2_grpc.ExecutorStub(channel)
      serialized_cards = executor_service_utils.serialize_cardinalities(
          {placement_literals.CLIENTS: num_clients})
      stub.SetCardinalities(
          executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = reference_resolving_executor.ReferenceResolvingExecutor(
          remote_exec)
      try:
        yield collections.namedtuple('_', 'executor tracers')(executor, tracers)
      finally:
        try:
          executor.close()
        finally:
          for tracer in tracers:
            tracer.close()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
def main(argv):
  """Serves a Lambda/Concurrent/Eager executor stack over gRPC until Ctrl-C.

  If `FLAGS.private_key` is set, `FLAGS.certificate_chain` must also be set.
  Both files are then read as bytes and the server binds a TLS port.
  Otherwise it binds an insecure port. In both cases the port comes from
  `FLAGS.port`.

  Args:
    argv: Command-line arguments; unused.

  Raises:
    ValueError: If a private key is given without a certificate chain.
  """
  del argv
  tf.compat.v1.enable_v2_behavior()
  innermost = tff.framework.EagerExecutor()
  service = tff.framework.ExecutorService(
      tff.framework.LambdaExecutor(tff.framework.ConcurrentExecutor(innermost)))
  server = grpc.server(
      concurrent.futures.ThreadPoolExecutor(max_workers=FLAGS.threads))

  server_creds = None
  if FLAGS.private_key:
    # Guard clause: a key without its chain is a configuration error.
    if not FLAGS.certificate_chain:
      raise ValueError(
          'Private key has been specified, but the certificate chain missing.'
      )
    with open(FLAGS.private_key, 'rb') as key_file:
      private_key = key_file.read()
    with open(FLAGS.certificate_chain, 'rb') as chain_file:
      certificate_chain = chain_file.read()
    server_creds = grpc.ssl_server_credentials(
        ((private_key, certificate_chain),))

  full_port_string = '[::]:' + str(FLAGS.port)
  if server_creds is None:
    server.add_insecure_port(full_port_string)
  else:
    server.add_secure_port(full_port_string, server_creds)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()

  try:
    # Block the main thread; the server does its work on other threads.
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    server.stop(None)