Example #1
    def test_create_value_reraises_type_error(self, mock_stub):
        mock_stub.create_value = mock.Mock(side_effect=TypeError)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)

        with self.assertRaises(TypeError):
            asyncio.run(executor.create_value(1, tf.int32))
Example #2
def _configure_remote_workers(default_num_clients, stubs, thread_pool_executor,
                              dispose_batch_size):
    """Configures `default_num_clients` across the workers behind `stubs`."""
    available_stubs = [stub for stub in stubs if stub.is_ready]
    logging.info('%s TFF workers available out of a total of %s.',
                 len(available_stubs), len(stubs))
    if not available_stubs:
        raise executors_errors.RetryableError(
            'No workers are ready; try again to reconnect.')
    remaining_clients = default_num_clients
    live_workers = []
    for stub_idx, stub in enumerate(available_stubs):
        remaining_stubs = len(available_stubs) - stub_idx
        default_num_clients_to_host = remaining_clients // remaining_stubs
        remaining_clients -= default_num_clients_to_host
        if default_num_clients_to_host > 0:
            ex = remote_executor.RemoteExecutor(stub, thread_pool_executor,
                                                dispose_batch_size)
            ex.set_cardinalities(
                {placements.CLIENTS: default_num_clients_to_host})
            live_workers.append(ex)
    return [
        _wrap_executor_in_threading_stack(e, can_resolve_references=False)
        for e in live_workers
    ]
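The loop above splits `default_num_clients` across the available workers using integer division, so any remainder accumulates on the later workers. A minimal, self-contained sketch of that arithmetic (the helper `split_clients` is hypothetical, introduced only to illustrate the distribution):

def split_clients(total_clients, num_workers):
    # Each worker takes an equal share of whatever clients remain, so any
    # remainder lands on the last workers, mirroring the loop above.
    remaining = total_clients
    counts = []
    for idx in range(num_workers):
        share = remaining // (num_workers - idx)
        remaining -= share
        counts.append(share)
    return counts

# 10 clients across 3 available workers -> [3, 3, 4]
assert split_clients(10, 3) == [3, 3, 4]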
Example #3
def test_context(rpc_mode='REQUEST_REPLY'):
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_executor = executor_stacks.local_executor_factory(
        num_clients=3).create_executor({})
    tracer = executor_test_utils.TracingExecutor(target_executor)
    service = executor_service.ExecutorService(tracer)
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_exec)
    try:
        yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
    finally:
        executor.close()
        tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server.stop(None)
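A hypothetical usage sketch for the helper above; it assumes `test_context` is wrapped with `contextlib.contextmanager` in the original test module (the decorator is not shown in this excerpt), and the computation below is illustrative only:

with test_context() as context:
    # The executor forwards work over gRPC to the in-process server; the
    # tracer records every call made against the wrapped local executor.
    value = asyncio.run(context.executor.create_value(10, tf.int32))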
Example #4
    def test_set_cardinalities_returns_none(self, mock_stub):
        mock_stub.get_executor.return_value = executor_pb2.GetExecutorResponse(
            executor=executor_pb2.ExecutorId(id='test_id'))
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        result = executor.set_cardinalities({placements.CLIENTS: 3})
        self.assertIsNone(result)
Example #5
    def test_create_call_reraises_type_error(self, mock_stub):
        mock_stub.create_call = mock.Mock(side_effect=TypeError)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.FunctionType(None, tf.int32)
        comp = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                           type_signature, executor)

        with self.assertRaises(TypeError):
            asyncio.run(executor.create_call(comp))
Example #6
    def test_create_value_reraises_grpc_error(self, mock_stub):
        mock_stub.create_value = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)

        with self.assertRaises(grpc.RpcError) as context:
            asyncio.run(executor.create_value(1, tf.int32))

        self.assertEqual(context.exception.code(), grpc.StatusCode.ABORTED)
Example #7
    def test_create_value_returns_remote_value(self, mock_stub):
        mock_stub.create_value.return_value = executor_pb2.CreateValueResponse(
        )
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)

        result = asyncio.run(executor.create_value(1, tf.int32))

        mock_stub.create_value.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Example #8
    def test_create_selection_reraises_type_error(self, mock_stub):
        mock_stub.create_selection = mock.Mock(side_effect=TypeError)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.StructType([tf.int32, tf.int32])
        source = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                             type_signature, executor)

        with self.assertRaises(TypeError):
            asyncio.run(executor.create_selection(source, 0))
Example #9
    def test_create_struct_reraises_type_error(self, mock_stub):
        mock_stub.create_struct = mock.Mock(side_effect=TypeError)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.TensorType(tf.int32)
        value_1 = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                              type_signature, executor)
        value_2 = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                              type_signature, executor)

        with self.assertRaises(TypeError):
            asyncio.run(executor.create_struct([value_1, value_2]))
Example #10
    def test_create_call_returns_remote_value(self, mock_stub):
        mock_stub.create_call.return_value = executor_pb2.CreateCallResponse()
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.FunctionType(None, tf.int32)
        fn = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                         type_signature, executor)

        result = asyncio.run(executor.create_call(fn, None))

        mock_stub.create_call.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Example #11
    def test_create_call_reraises_grpc_error(self, mock_stub):
        mock_stub.create_call = mock.Mock(
            side_effect=_raise_non_retryable_grpc_error)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.FunctionType(None, tf.int32)
        comp = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                           type_signature, executor)

        with self.assertRaises(grpc.RpcError) as context:
            asyncio.run(executor.create_call(comp, None))

        self.assertEqual(context.exception.code(), grpc.StatusCode.ABORTED)
Example #12
    def test_create_selection_returns_remote_value(self, mock_stub):
        mock_stub.create_selection.return_value = executor_pb2.CreateSelectionResponse(
        )
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.StructType([tf.int32, tf.int32])
        source = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                             type_signature, executor)

        result = asyncio.run(executor.create_selection(source, 0))

        mock_stub.create_selection.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Example #13
    def test_create_struct_returns_remote_value(self, mock_stub):
        mock_stub.create_struct.return_value = executor_pb2.CreateStructResponse(
        )
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        type_signature = computation_types.TensorType(tf.int32)
        value_1 = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                              type_signature, executor)
        value_2 = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                              type_signature, executor)

        result = asyncio.run(executor.create_struct([value_1, value_2]))

        mock_stub.create_struct.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Example #14
    def test_compute_returns_result(self, mock_stub):
        tensor_proto = tf.make_tensor_proto(1)
        any_pb = any_pb2.Any()
        any_pb.Pack(tensor_proto)
        value = executor_pb2.Value(tensor=any_pb)
        mock_stub.compute.return_value = executor_pb2.ComputeResponse(
            value=value)
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        executor.set_cardinalities({placements.CLIENTS: 3})
        type_signature = computation_types.FunctionType(None, tf.int32)
        comp = remote_executor.RemoteValue(executor_pb2.ValueRef(),
                                           type_signature, executor)

        result = asyncio.run(comp.compute())

        mock_stub.compute.assert_called_once()
        self.assertEqual(result, 1)
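For reference, a minimal sketch of the tensor round-trip that the mocked `ComputeResponse` above relies on; it uses only TensorFlow and protobuf, and the variable names are illustrative:

import tensorflow as tf
from google.protobuf import any_pb2
from tensorflow.core.framework import tensor_pb2

any_pb = any_pb2.Any()
any_pb.Pack(tf.make_tensor_proto(1))    # pack the constant 1 into an Any

unpacked = tensor_pb2.TensorProto()
any_pb.Unpack(unpacked)                 # recover the TensorProto
assert tf.make_ndarray(unpacked) == 1   # deserializes back to 1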
Example #15
def test_context(rpc_mode='REQUEST_REPLY'):
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_factory = executor_stacks.local_executor_factory(num_clients=3)
    tracers = []

    def _tracer_fn(cardinalities):
        tracer = executor_test_utils.TracingExecutor(
            target_factory.create_executor(cardinalities))
        tracers.append(tracer)
        return tracer

    service = executor_service.ExecutorService(
        executor_stacks.ResourceManagingExecutorFactory(_tracer_fn))
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))
    stub = executor_pb2_grpc.ExecutorStub(channel)
    serialized_cards = executor_service_utils.serialize_cardinalities(
        {placement_literals.CLIENTS: 3})
    stub.SetCardinalities(
        executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))

    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_exec)
    try:
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)
    finally:
        executor.close()
        for tracer in tracers:
            tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server.stop(None)
Example #16
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
  """Creates context to execute computations using remote workers on `channels`."""
  # TODO(b/166634524): Reparameterize worker_pool_executor_factory to
  # construct remote executors, rename to remote_executor_factory or something
  # similar.
  executors = [
      remote_executor.RemoteExecutor(channel, rpc_mode, thread_pool_executor,
                                     dispose_batch_size) for channel in channels
  ]
  factory = executor_stacks.worker_pool_executor_factory(
      executors=executors,
      max_fanout=max_fanout,
  )

  return execution_context.ExecutionContext(
      executor_fn=factory, compiler_fn=compiler.transform_to_native_form)
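A hedged usage sketch for the factory above; the worker address is a placeholder, and installing the returned context as the process-wide default is only indicated in a comment because the exact API location varies across TFF releases:

import grpc

channels = [grpc.insecure_channel('localhost:8000')]  # placeholder address
context = create_remote_execution_context(channels, dispose_batch_size=10)
# The ExecutionContext would then typically be installed as the default
# context, e.g. via tff.framework.set_default_context(context).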
Example #17
def test_context():
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_factory = executor_test_utils.LocalTestExecutorFactory(
        default_num_clients=3)
    tracers = []

    def _tracer_fn(cardinalities):
        tracer = executor_test_utils.TracingExecutor(
            target_factory.create_executor(cardinalities))
        tracers.append(tracer)
        return tracer

    service = executor_service.ExecutorService(
        executor_test_utils.BasicTestExFactory(_tracer_fn))
    executor_pb2_grpc.add_ExecutorGroupServicer_to_server(service, server)
    server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))

    stub = remote_executor_grpc_stub.RemoteExecutorGrpcStub(channel)
    remote_exec = remote_executor.RemoteExecutor(stub)
    remote_exec.set_cardinalities({placements.CLIENTS: 3})
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_exec)
    try:
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)
    finally:
        executor.close()
        for tracer in tracers:
            tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server.stop(None)
Example #18
def remote_executor_factory(channels,
                            rpc_mode='REQUEST_REPLY',
                            thread_pool_executor=None,
                            dispose_batch_size=20,
                            max_fanout=100) -> executor_factory.ExecutorFactory:
  """Create an executor backed by remote workers.

  Args:
    channels: A list of `grpc.Channels` hosting services which can execute TFF
      work.
    rpc_mode: A string specifying the connection mode between the local host and
      `channels`.
    thread_pool_executor: Optional concurrent.futures.Executor used to wait for
      the reply to a streaming RPC message. Uses the default Executor if not
      specified.
    dispose_batch_size: The batch size for requests to dispose of remote worker
      values. Lower values will result in more requests to the remote worker,
      but will result in values being cleaned up sooner and therefore may result
      in lower memory usage on the remote worker.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist of
      multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.
  """
  py_typecheck.check_type(channels, list)
  if not channels:
    raise ValueError('The list of channels cannot be empty.')

  remote_executors = [
      remote_executor.RemoteExecutor(channel, rpc_mode, thread_pool_executor,
                                     dispose_batch_size) for channel in channels
  ]

  def _configure_remote_executor(ex, cardinalities, loop):
    """Configures `ex` to run the appropriate number of clients."""
    loop.run_until_complete(ex.set_cardinalities(cardinalities))
    return

  def _configure_remote_workers(cardinalities):
    loop = asyncio.new_event_loop()
    try:
      if not cardinalities.get(placement_literals.CLIENTS):
        for ex in remote_executors:
          _configure_remote_executor(ex, cardinalities, loop)
        return [_wrap_executor_in_threading_stack(e) for e in remote_executors]

      remaining_clients = cardinalities[placement_literals.CLIENTS]
      clients_per_most_executors = remaining_clients // len(remote_executors)

      for ex in remote_executors[:-1]:
        _configure_remote_executor(
            ex, {placement_literals.CLIENTS: clients_per_most_executors}, loop)
        remaining_clients -= clients_per_most_executors
      _configure_remote_executor(
          remote_executors[-1], {placement_literals.CLIENTS: remaining_clients},
          loop)
    finally:
      loop.close()
    return [_wrap_executor_in_threading_stack(e) for e in remote_executors]

  flat_stack_fn = _configure_remote_workers
  unplaced_ex_factory = UnplacedExecutorFactory(use_caching=False)
  composing_executor_factory = ComposingExecutorFactory(
      max_fanout=max_fanout,
      unplaced_ex_factory=unplaced_ex_factory,
      flat_stack_fn=flat_stack_fn,
  )

  return ResourceManagingExecutorFactory(
      executor_stack_fn=composing_executor_factory.create_executor)
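A short, hypothetical usage sketch for the factory defined above; the addresses and client count are placeholders:

import grpc

channels = [
    grpc.insecure_channel('localhost:8000'),
    grpc.insecure_channel('localhost:8001'),
]
factory = remote_executor_factory(channels, max_fanout=10)
# Cardinalities are supplied when the executor is created; here ten clients
# are split across the two workers by _configure_remote_workers.
executor = factory.create_executor({placement_literals.CLIENTS: 10})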
Example #19
def create_remote_executor():
    port = portpicker.pick_unused_port()
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    return remote_executor.RemoteExecutor(channel, 'REQUEST_REPLY')
Example #20
def remote_executor_factory(
        channels: List[grpc.Channel],
        rpc_mode: str = 'REQUEST_REPLY',
        thread_pool_executor: Optional[futures.Executor] = None,
        dispose_batch_size: int = 20,
        max_fanout: int = 100) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

  Args:
    channels: A list of `grpc.Channels` hosting services which can execute TFF
      work.
    rpc_mode: A string specifying the connection mode between the local host and
      `channels`.
    thread_pool_executor: Optional concurrent.futures.Executor used to wait for
      the reply to a streaming RPC message. Uses the default Executor if not
      specified.
    dispose_batch_size: The batch size for requests to dispose of remote worker
      values. Lower values will result in more requests to the remote worker,
      but will result in values being cleaned up sooner and therefore may result
      in lower memory usage on the remote worker.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist of
      multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.
  """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')
    py_typecheck.check_type(rpc_mode, str)
    if thread_pool_executor is not None:
        py_typecheck.check_type(thread_pool_executor, futures.Executor)
    py_typecheck.check_type(dispose_batch_size, int)
    py_typecheck.check_type(max_fanout, int)

    remote_executors = []
    for channel in channels:
        remote_executors.append(
            remote_executor.RemoteExecutor(
                channel=channel,
                rpc_mode=rpc_mode,
                thread_pool_executor=thread_pool_executor,
                dispose_batch_size=dispose_batch_size))

    def _get_event_loop():
        should_close_loop = False
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                should_close_loop = True
        except RuntimeError:
            loop = asyncio.new_event_loop()
            should_close_loop = True
        return loop, should_close_loop

    def _configure_remote_executor(ex, cardinalities, loop):
        """Configures `ex` to run the appropriate number of clients."""
        if loop.is_running():
            asyncio.run_coroutine_threadsafe(
                ex.set_cardinalities(cardinalities), loop)
        else:
            loop.run_until_complete(ex.set_cardinalities(cardinalities))
        return

    def _configure_remote_workers(cardinalities):
        loop, must_close_loop = _get_event_loop()
        try:
            if not cardinalities.get(placement_literals.CLIENTS):
                for ex in remote_executors:
                    _configure_remote_executor(ex, cardinalities, loop)
                return [
                    _wrap_executor_in_threading_stack(e)
                    for e in remote_executors
                ]

            remaining_clients = cardinalities[placement_literals.CLIENTS]
            live_workers = []
            for ex_idx, ex in enumerate(remote_executors):
                remaining_executors = len(remote_executors) - ex_idx
                num_clients_to_host = remaining_clients // remaining_executors
                remaining_clients -= num_clients_to_host
                if num_clients_to_host > 0:
                    _configure_remote_executor(
                        ex, {placement_literals.CLIENTS: num_clients_to_host},
                        loop)
                    live_workers.append(ex)
        finally:
            if must_close_loop:
                loop.stop()
                loop.close()
        return [_wrap_executor_in_threading_stack(e) for e in live_workers]

    flat_stack_fn = _configure_remote_workers
    unplaced_ex_factory = UnplacedExecutorFactory(use_caching=False)
    composing_executor_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=unplaced_ex_factory,
        flat_stack_fn=flat_stack_fn,
    )

    return ResourceManagingExecutorFactory(
        executor_stack_fn=composing_executor_factory.create_executor,
        ensure_closed=remote_executors)
Example #21
def remote_executor_factory(
    channels: List[grpc.Channel],
    thread_pool_executor: Optional[futures.Executor] = None,
    dispose_batch_size: int = 20,
    max_fanout: int = 100,
    default_num_clients: int = 0,
) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

  Args:
    channels: A list of `grpc.Channels` hosting services which can execute TFF
      work.
    thread_pool_executor: Optional concurrent.futures.Executor used to wait for
      the reply to a streaming RPC message. Uses the default Executor if not
      specified.
    dispose_batch_size: The batch size for requests to dispose of remote worker
      values. Lower values will result in more requests to the remote worker,
      but will result in values being cleaned up sooner and therefore may result
      in lower memory usage on the remote worker.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist of
      multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.
    default_num_clients: The number of clients to use for simulations where the
      number of clients cannot be inferred. Usually the number of clients will
      be inferred from the number of values passed to computations which accept
      client-placed values. However, when this inference isn't possible (such as
      in the case of a no-argument or non-federated computation) this default
      will be used instead.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.
  """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')
    if thread_pool_executor is not None:
        py_typecheck.check_type(thread_pool_executor, futures.Executor)
    py_typecheck.check_type(dispose_batch_size, int)
    py_typecheck.check_type(max_fanout, int)
    py_typecheck.check_type(default_num_clients, int)

    remote_executors = []
    for channel in channels:
        remote_executors.append(
            remote_executor.RemoteExecutor(
                channel=channel,
                thread_pool_executor=thread_pool_executor,
                dispose_batch_size=dispose_batch_size))

    def _flat_stack_fn(cardinalities):
        num_clients = cardinalities.get(placements.CLIENTS,
                                        default_num_clients)
        return _configure_remote_workers(num_clients, remote_executors)

    unplaced_ex_factory = UnplacedExecutorFactory(use_caching=False)
    composing_executor_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=unplaced_ex_factory,
        flat_stack_fn=_flat_stack_fn,
    )

    return ReconstructOnChangeExecutorFactory(
        underlying_stack=composing_executor_factory,
        ensure_closed=remote_executors,
        change_query=_CardinalitiesOrReadyListChanged(
            maybe_ready_list=remote_executors))
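A hedged sketch of the `default_num_clients` fallback described in the docstring above; the worker address is a placeholder:

import grpc

channels = [grpc.insecure_channel('localhost:8000')]
factory = remote_executor_factory(channels, default_num_clients=3)
# With no CLIENTS entry in the cardinalities, _flat_stack_fn falls back to
# default_num_clients, so the single worker is configured for 3 clients.
executor = factory.create_executor({})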