Example #1
0
 def test_serialize_deserialize_clients_alone(self):
     """Round-trips a CLIENTS-only cardinality map through the serializers.

     Checks that every serialized entry is a
     `SetCardinalitiesRequest.Cardinality` proto, and that deserializing the
     serialized entries reproduces the original mapping.
     """
     original = {placement_literals.CLIENTS: 10}
     serialized = executor_service_utils.serialize_cardinalities(original)
     for entry in serialized:
         self.assertIsInstance(
             entry, executor_pb2.SetCardinalitiesRequest.Cardinality)
     round_tripped = executor_service_utils.deserialize_cardinalities(
         serialized)
     self.assertEqual(original, round_tripped)
  async def set_cardinalities(
      self, cardinalities: Mapping[placement_literals.PlacementLiteral, int]):
    """Sends placement cardinalities to the remote executor service.

    The mapping is serialized into a `SetCardinalitiesRequest`. When a
    bidirectional stream is open, the request is wrapped in an
    `ExecuteRequest` and sent over that stream. Otherwise it is issued as a
    unary `SetCardinalities` call through the module-level `_request` helper.

    Args:
      cardinalities: Number of participants for each placement literal.
    """
    wire_cardinalities = executor_service_utils.serialize_cardinalities(
        cardinalities)
    set_request = executor_pb2.SetCardinalitiesRequest(
        cardinalities=wire_cardinalities)

    if self._bidi_stream is not None:
      envelope = executor_pb2.ExecuteRequest(set_cardinalities=set_request)
      await self._bidi_stream.send_request(envelope)
      return

    # NOTE(review): `_request` is defined outside this view. It is invoked
    # without `await`, so presumably it performs the RPC synchronously.
    _request(self._stub.SetCardinalities, set_request)
Example #3
0
  def __init__(self,
               ex_factory: executor_factory.ExecutorFactory,
               num_clients: int = 0):
    """Starts an in-process executor service and connects a stub to it.

    A gRPC server hosting an `ExecutorService` is started on an unused local
    port. An insecure localhost channel and `ExecutorStub` are then opened to
    it, and the CLIENTS cardinality is registered with an initial
    `SetCardinalities` RPC.

    Args:
      ex_factory: Factory handed to the served `ExecutorService`.
      num_clients: Cardinality registered for `placement_literals.CLIENTS`.

    Raises:
      Whatever the initial cardinality serialization or `SetCardinalities`
      RPC raises. The channel, server and worker pool are released before
      the exception propagates.
    """
    port = portpicker.pick_unused_port()
    # Kept on the instance (not only as a local) so the worker pool can be
    # shut down once the server is no longer needed.
    self._server_pool = logging_pool.pool(max_workers=1)
    self._server = grpc.server(self._server_pool)
    self._server.add_insecure_port('[::]:{}'.format(port))
    self._service = executor_service.ExecutorService(ex_factory=ex_factory)
    executor_pb2_grpc.add_ExecutorServicer_to_server(self._service,
                                                     self._server)
    self._server.start()
    self._channel = grpc.insecure_channel('localhost:{}'.format(port))
    self._stub = executor_pb2_grpc.ExecutorStub(self._channel)

    try:
      serialized_cards = executor_service_utils.serialize_cardinalities(
          {placement_literals.CLIENTS: num_clients})
      self._stub.SetCardinalities(
          executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))
    except Exception:
      # The caller never receives a half-constructed object and so cannot
      # tear it down. Release the already-running server, channel and pool
      # here, then re-raise the original error.
      try:
        self._channel.close()
      except AttributeError:
        pass  # Some gRPC channel implementations do not provide close().
      self._server.stop(None)
      self._server_pool.shutdown(wait=False)
      raise
Example #4
0
def test_context(rpc_mode='REQUEST_REPLY'):
    """Yields an executor that talks to an in-process executor service.

    Starts a local gRPC server hosting an `ExecutorService`. Every executor
    the service creates is wrapped in a `TracingExecutor`. The server is told
    there are 3 clients, and a `RemoteExecutor` (wrapped in a
    `ReferenceResolvingExecutor`) is connected to it over an insecure
    localhost channel.

    NOTE(review): this is a generator; presumably it is decorated with
    `contextlib.contextmanager` where it is defined. Confirm.

    Args:
      rpc_mode: Passed through to `remote_executor.RemoteExecutor`.

    Yields:
      A namedtuple with fields `executor` (the `ReferenceResolvingExecutor`)
      and `tracers` (the live list of `TracingExecutor`s created by the
      service; it grows as the service creates executors).
    """
    import contextlib  # Local import keeps this block self-contained.

    # Cleanup callbacks run in reverse registration order: executor, tracers,
    # channel, server, then worker pool. ExitStack keeps unwinding even when
    # set-up fails part-way or an earlier callback raises, so nothing leaks.
    with contextlib.ExitStack() as cleanup:
        port = portpicker.pick_unused_port()
        server_pool = logging_pool.pool(max_workers=1)
        cleanup.callback(server_pool.shutdown, wait=False)
        server = grpc.server(server_pool)
        server.add_insecure_port('[::]:{}'.format(port))
        target_factory = executor_stacks.local_executor_factory(num_clients=3)
        tracers = []

        def _tracer_fn(cardinalities):
            tracer = executor_test_utils.TracingExecutor(
                target_factory.create_executor(cardinalities))
            tracers.append(tracer)
            return tracer

        service = executor_service.ExecutorService(
            executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
        executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
        server.start()
        cleanup.callback(server.stop, None)

        channel = grpc.insecure_channel('localhost:{}'.format(port))

        def _close_channel():
            try:
                channel.close()
            except AttributeError:
                pass  # Public gRPC channel doesn't support close()

        cleanup.callback(_close_channel)

        def _close_tracers():
            # Reads `tracers` at exit time, so tracers the service creates
            # after this registration are closed as well.
            for tracer in tracers:
                tracer.close()

        cleanup.callback(_close_tracers)

        stub = executor_pb2_grpc.ExecutorStub(channel)
        serialized_cards = executor_service_utils.serialize_cardinalities(
            {placement_literals.CLIENTS: 3})
        stub.SetCardinalities(
            executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))

        remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
        executor = reference_resolving_executor.ReferenceResolvingExecutor(
            remote_exec)
        cleanup.callback(executor.close)
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)