def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields a remote executor stack backed by an in-process gRPC server.

    A local executor stack created with `num_clients=3` is wrapped in a
    `TracingExecutor`, hosted by an `ExecutorService`, and served on an
    insecure port. The client side is a `ReferenceResolvingExecutor` over a
    `RemoteExecutor` connected to `localhost` on that port.

    Cleanup is nested so that each resource is released even if creating a
    later resource, or closing an earlier one, raises; the server is always
    stopped once it has been started.

    Args:
      rpc_mode: The RPC mode passed through to `RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` (the client-side executor) and
      `tracer` (the server-side `TracingExecutor`).
    """
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_executor = executor_stacks.local_executor_factory(
        num_clients=3).create_executor({})
    tracer = executor_test_utils.TracingExecutor(target_executor)
    service = executor_service.ExecutorService(tracer)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()
    try:
        channel = grpc.insecure_channel('localhost:{}'.format(port))
        try:
            remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
            executor = (
                reference_resolving_executor.ReferenceResolvingExecutor(
                    remote_exec))
            try:
                yield collections.namedtuple('_', 'executor tracer')(executor,
                                                                     tracer)
            finally:
                executor.close()
        finally:
            tracer.close()
            try:
                channel.close()
            except AttributeError:
                pass  # Public gRPC channel doesn't support close()
    finally:
        # Reached whenever the server was started, regardless of what failed.
        server.stop(None)
Exemple #2
0
def test_context(rpc_mode='REQUEST_REPLY'):
  """Yields a `LambdaExecutor` over a `RemoteExecutor` to an in-process server.

  The server hosts an `EagerExecutor` wrapped in a `TracingExecutor` on an
  insecure port. The client-side `LambdaExecutor` is installed as the default
  executor for the duration of the context and reset afterwards.

  Cleanup is nested so that the default executor is reset, the channel is
  closed and the server is stopped even if an earlier cleanup step, or the
  client-side setup, raises.

  Args:
    rpc_mode: The RPC mode passed through to `RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the client-side `LambdaExecutor`) and
    `tracer` (the server-side `TracingExecutor`).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = eager_executor.EagerExecutor()
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
      executor = lambda_executor.LambdaExecutor(remote_exec)
      set_default_executor.set_default_executor(executor)
      try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
      finally:
        try:
          # NOTE(review): `__del__` is invoked directly as an explicit teardown
          # hook; prefer a public `close()` if this RemoteExecutor offers one.
          remote_exec.__del__()
        finally:
          set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        pass  # Public gRPC channel doesn't support close()
  finally:
    server.stop(None)
Exemple #3
0
def main(argv):
    """Serves an eager TFF executor on a TLS-secured gRPC port until interrupted.

    The private key and certificate chain are read from the files named by
    `FLAGS.private_key` and `FLAGS.certificate_chain`, and the server listens
    on `FLAGS.endpoint` with `FLAGS.threads` worker threads.

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # Unused.
    tf.compat.v1.enable_v2_behavior()

    # TODO(b/134543154): Replace this with the complete local executor stack.
    eager_ex = tff.framework.EagerExecutor()
    executor_svc = tff.framework.ExecutorService(eager_ex)
    worker_pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=FLAGS.threads)
    grpc_server = grpc.server(worker_pool)

    def _read_bytes(path):
        # Key material is read in binary mode and passed through unchanged.
        with open(path, 'rb') as handle:
            return handle.read()

    key_bytes = _read_bytes(FLAGS.private_key)
    chain_bytes = _read_bytes(FLAGS.certificate_chain)
    key_cert_pairs = ((key_bytes, chain_bytes),)
    tls_creds = grpc.ssl_server_credentials(key_cert_pairs)

    grpc_server.add_secure_port(FLAGS.endpoint, tls_creds)
    executor_pb2_grpc.add_ExecutorServicer_to_server(executor_svc, grpc_server)
    grpc_server.start()

    # Park the main thread; RPCs are handled on the server's worker pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(None)
def server_context(ex_factory: executor_factory.ExecutorFactory,
                   num_threads: int,
                   port: int,
                   credentials: Optional[grpc.ServerCredentials] = None,
                   options: Optional[List[Tuple[Any, Any]]] = None):
    """Context manager yielding gRPC server hosting simulation component.

    Args:
      ex_factory: The executor factory to be hosted by the server.
      num_threads: The number of network threads to use for handling gRPC
        calls.
      port: The port to listen on (for gRPC), must be a non-zero integer.
      credentials: The optional credentials to use for the secure connection if
        any, or `None` if the server should open an insecure port. If
        specified, must be a valid `ServerCredentials` object that can be
        accepted by the gRPC server's `add_secure_port()`.
      options: The optional `list` of server options, each in the
        `(key, value)` format accepted by the `grpc.server()` constructor.

    Yields:
      The constructed gRPC server.

    Raises:
      ValueError: If `num_threads` or `port` are invalid.
    """
    py_typecheck.check_type(ex_factory, executor_factory.ExecutorFactory)
    py_typecheck.check_type(num_threads, int)
    py_typecheck.check_type(port, int)
    if credentials is not None:
        py_typecheck.check_type(credentials, grpc.ServerCredentials)
    if num_threads < 1:
        raise ValueError('The number of threads must be a positive integer.')
    if port < 1:
        raise ValueError('The server port must be a positive integer.')
    # Bound up front so the `finally` clause can tell which resources were
    # actually created. Without this, a failure early in setup would raise
    # NameError during cleanup and mask the original exception.
    thread_pool_executor = None
    server = None
    try:
        service = executor_service.ExecutorService(ex_factory)
        server_kwargs = {}
        if options is not None:
            server_kwargs['options'] = options
        thread_pool_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=num_threads)
        server = grpc.server(thread_pool_executor, **server_kwargs)
        full_port_string = '[::]:{}'.format(port)
        if credentials is not None:
            server.add_secure_port(full_port_string, credentials)
        else:
            server.add_insecure_port(full_port_string)
        executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
        server.start()
        yield server
    except KeyboardInterrupt:
        logging.info('Server stopped by KeyboardInterrupt.')
    finally:
        logging.info('Shutting down server.')
        if thread_pool_executor is not None:
            thread_pool_executor.shutdown(wait=False)
        if server is not None:
            server.stop(None)
        ex_factory.clean_up_executors()
 def __init__(self, executor):
     """Hosts `executor` on an in-process gRPC server and opens a stub to it.

     Args:
       executor: The executor wrapped in an `ExecutorService` and served.
     """
     # Server side: single-worker pool, insecure port on a free local port.
     free_port = portpicker.pick_unused_port()
     self._server = grpc.server(logging_pool.pool(max_workers=1))
     self._server.add_insecure_port('[::]:{}'.format(free_port))
     self._service = executor_service.ExecutorService(executor)
     executor_pb2_grpc.add_ExecutorServicer_to_server(self._service,
                                                      self._server)
     self._server.start()

     # Client side: an insecure channel to the same port and a stub over it.
     self._channel = grpc.insecure_channel('localhost:{}'.format(free_port))
     self._stub = executor_pb2_grpc.ExecutorStub(self._channel)
Exemple #6
0
 def setUp(self):
     """Starts an in-process `ExecutorService` and connects a stub to it."""
     super(ExecutorServiceTest, self).setUp()
     listen_port = portpicker.pick_unused_port()
     self._server = grpc.server(logging_pool.pool(max_workers=1))
     self._server.add_insecure_port('[::]:{}'.format(listen_port))
     hosted_executor = eager_executor.EagerExecutor()
     self._service = executor_service.ExecutorService(hosted_executor)
     executor_pb2_grpc.add_ExecutorServicer_to_server(self._service,
                                                      self._server)
     self._server.start()
     target = 'localhost:{}'.format(listen_port)
     self._channel = grpc.insecure_channel(target)
     self._stub = executor_pb2_grpc.ExecutorStub(self._channel)
Exemple #7
0
def run_server(executor, num_threads, port, credentials=None, options=None):
    """Hosts `executor` behind a gRPC server in the current process.

    Blocks indefinitely; a keyboard interrupt stops the server and returns.

    Args:
      executor: The executor to be hosted by the server.
      num_threads: The number of network threads to use for handling gRPC
        calls.
      port: The port to listen on (for gRPC), must be a non-zero integer.
      credentials: The optional credentials to use for the secure connection if
        any, or `None` if the server should open an insecure port. If
        specified, must be a valid `ServerCredentials` object that can be
        accepted by the gRPC server's `add_secure_port()`.
      options: The optional `list` of server options, each in the
        `(key, value)` format accepted by the `grpc.server()` constructor.

    Raises:
      ValueError: If `num_threads` or `port` are invalid.
    """
    py_typecheck.check_type(executor, framework.Executor)
    py_typecheck.check_type(num_threads, int)
    py_typecheck.check_type(port, int)
    if credentials is not None:
        py_typecheck.check_type(credentials, grpc.ServerCredentials)
    if num_threads < 1:
        raise ValueError('The number of threads must be a positive integer.')
    if port < 1:
        raise ValueError('The server port must be a positive integer.')

    service = framework.ExecutorService(executor)
    extra_kwargs = {} if options is None else {'options': options}
    worker_pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads)
    server = grpc.server(worker_pool, **extra_kwargs)

    address = '[::]:{}'.format(port)
    if credentials is None:
        server.add_insecure_port(address)
    else:
        server.add_secure_port(address, credentials)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    # Keep the process alive until interrupted from the keyboard.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(None)
  def __init__(self,
               ex_factory: executor_factory.ExecutorFactory,
               num_clients: int = 0):
    """Serves `ex_factory` in-process and registers the client cardinality.

    Args:
      ex_factory: The executor factory hosted by the `ExecutorService`.
      num_clients: The CLIENTS cardinality sent via `SetCardinalities`.
    """
    listen_port = portpicker.pick_unused_port()
    self._server_pool = logging_pool.pool(max_workers=1)
    self._server = grpc.server(self._server_pool)
    self._server.add_insecure_port('[::]:{}'.format(listen_port))
    self._service = executor_service.ExecutorService(ex_factory=ex_factory)
    executor_pb2_grpc.add_ExecutorServicer_to_server(
        self._service, self._server)
    self._server.start()

    address = 'localhost:{}'.format(listen_port)
    self._channel = grpc.insecure_channel(address)
    self._stub = executor_pb2_grpc.ExecutorStub(self._channel)

    # Register the number of CLIENTS with the service over the stub.
    cardinalities = {placement_literals.CLIENTS: num_clients}
    request = executor_pb2.SetCardinalitiesRequest(
        cardinalities=executor_serialization.serialize_cardinalities(
            cardinalities))
    self._stub.SetCardinalities(request)
def test_context():
  """Yields a `RemoteExecutor` connected to an in-process gRPC server.

  The server hosts an `EagerExecutor` behind an `ExecutorService` on an
  insecure port. The yielded executor is installed as the default executor
  for the duration of the context. The default is reset, the channel closed
  and the server stopped even if the body of the `with` block raises.

  Yields:
    The `RemoteExecutor` connected to the in-process server.
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = eager_executor.EagerExecutor()
  service = executor_service.ExecutorService(target_executor)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    try:
      executor = remote_executor.RemoteExecutor(channel)
      set_default_executor.set_default_executor(executor)
      try:
        yield executor
      finally:
        # Without this `finally`, an exception in the caller's block would
        # leave the remote executor installed as the global default.
        set_default_executor.set_default_executor()
    finally:
      try:
        channel.close()
      except AttributeError:
        # Public gRPC channel doesn't support close(); drop our reference.
        del channel
  finally:
    server.stop(None)
Exemple #10
0
def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields a remote executor stack backed by an in-process gRPC server.

    The server hosts an `ExecutorService` over a
    `ResourceManagingExecutorFactory`. Every executor that factory creates is
    a `TracingExecutor` around a local executor stack, and is appended to
    `tracers`. The client sets the CLIENTS cardinality via `SetCardinalities`
    and then wraps a `RemoteExecutor` in a `ReferenceResolvingExecutor`.

    Cleanup is nested so that the server is always stopped once started, even
    if the cardinality RPC, executor construction, or an earlier close() call
    raises.

    Args:
      rpc_mode: The RPC mode passed through to `RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` (the client-side executor) and
      `tracers` (the list of server-side `TracingExecutor`s created so far).
    """
    # Used both for the local stack and the cardinality sent to the server,
    # so the two cannot drift apart.
    num_clients = 3
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_factory = executor_stacks.local_executor_factory(
        num_clients=num_clients)
    tracers = []

    def _tracer_fn(cardinalities):
        tracer = executor_test_utils.TracingExecutor(
            target_factory.create_executor(cardinalities))
        tracers.append(tracer)
        return tracer

    service = executor_service.ExecutorService(
        executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()
    try:
        channel = grpc.insecure_channel('localhost:{}'.format(port))
        try:
            stub = executor_pb2_grpc.ExecutorStub(channel)
            serialized_cards = executor_service_utils.serialize_cardinalities(
                {placement_literals.CLIENTS: num_clients})
            stub.SetCardinalities(
                executor_pb2.SetCardinalitiesRequest(
                    cardinalities=serialized_cards))

            remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
            executor = (
                reference_resolving_executor.ReferenceResolvingExecutor(
                    remote_exec))
            try:
                yield collections.namedtuple('_', 'executor tracers')(
                    executor, tracers)
            finally:
                executor.close()
        finally:
            # Closes whatever tracers `_tracer_fn` has appended so far.
            for tracer in tracers:
                tracer.close()
            try:
                channel.close()
            except AttributeError:
                pass  # Public gRPC channel doesn't support close()
    finally:
        server.stop(None)
Exemple #11
0
def main(argv):
    """Serves a TFF executor stack over gRPC until interrupted.

    The hosted stack is `LambdaExecutor(ConcurrentExecutor(EagerExecutor()))`.
    When both `FLAGS.private_key` and `FLAGS.certificate_chain` are set, the
    server listens on a TLS-secured port. When neither is set, it listens on
    an insecure port.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: If only one of `FLAGS.private_key` and
        `FLAGS.certificate_chain` is set.
    """
    del argv
    tf.compat.v1.enable_v2_behavior()

    service = tff.framework.ExecutorService(
        tff.framework.LambdaExecutor(
            tff.framework.ConcurrentExecutor(tff.framework.EagerExecutor())))

    server = grpc.server(
        concurrent.futures.ThreadPoolExecutor(max_workers=FLAGS.threads))

    if FLAGS.private_key and FLAGS.certificate_chain:
        with open(FLAGS.private_key, 'rb') as f:
            private_key = f.read()
        with open(FLAGS.certificate_chain, 'rb') as f:
            certificate_chain = f.read()
        server_creds = grpc.ssl_server_credentials(((
            private_key,
            certificate_chain,
        ), ))
    elif FLAGS.private_key:
        raise ValueError(
            'Private key has been specified, but the certificate chain missing.'
        )
    elif FLAGS.certificate_chain:
        # Previously this case silently fell back to an insecure port even
        # though the operator clearly asked for TLS.
        raise ValueError(
            'Certificate chain has been specified, but the private key is '
            'missing.')
    else:
        server_creds = None

    full_port_string = '[::]:{}'.format(FLAGS.port)
    if server_creds is not None:
        server.add_secure_port(full_port_string, server_creds)
    else:
        server.add_insecure_port(full_port_string)

    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(None)