Example #1
def worker_pool_executor_factory(executors,
                                 max_fanout=100
                                 ) -> executor_factory.ExecutorFactory:
    """Create an executor backed by a worker pool.

    Args:
      executors: A list of `tff.framework.Executor` instances that forward work
        to workers in the worker pool. These can be any type of executors, but
        in most scenarios, they will be instances of
        `tff.framework.RemoteExecutor`.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.
    """
    py_typecheck.check_type(executors, list)
    py_typecheck.check_type(max_fanout, int)
    if not executors:
        raise ValueError('The list of executors cannot be empty.')
    if max_fanout < 2:
        raise ValueError('Max fanout must be greater than 1.')
    executors = [_complete_stack(e) for e in executors]

    def _stack_fn(cardinalities):
        del cardinalities  # Unused
        return _aggregate_stacks(executors, max_fanout)

    return executor_factory.ExecutorFactoryImpl(executor_stack_fn=_stack_fn)

  def test_concrete_class_instantiates_stack_fn(self):

    def _stack_fn(x):
      del x  # Unused
      return eager_tf_executor.EagerTFExecutor()

    factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
    self.assertIsInstance(factory, executor_factory.ExecutorFactoryImpl)
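
# A minimal usage sketch of the factory above. The worker addresses are
# hypothetical placeholders; the `RemoteExecutor` and gRPC channel calls
# mirror the patterns shown in the later examples.
channels = [grpc.insecure_channel('localhost:8000'),
            grpc.insecure_channel('localhost:8001')]
executors = [remote_executor.RemoteExecutor(ch, 'REQUEST_REPLY')
             for ch in channels]
factory = worker_pool_executor_factory(executors, max_fanout=10)
executor = factory.create_executor({})  # Cardinalities are ignored by _stack_fn.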

Example #3
def worker_pool_executor_factory(executors,
                                 max_fanout=100
                                 ) -> executor_factory.ExecutorFactory:
    """Create an executor backed by a worker pool.

    Args:
      executors: A list of `tff.framework.Executor` instances that forward work
        to workers in the worker pool. These can be any type of executors, but
        in most scenarios, they will be instances of
        `tff.framework.RemoteExecutor`.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.
    """
    py_typecheck.check_type(executors, list)
    if not executors:
        raise ValueError('The list of executors cannot be empty.')
    executors = [_wrap_executor_in_threading_stack(e) for e in executors]
    flat_stack_fn = lambda _: executors
    unplaced_ex_factory = UnplacedExecutorFactory(use_caching=True)
    composing_executor_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=unplaced_ex_factory,
        flat_stack_fn=flat_stack_fn,
    )
    return executor_factory.ExecutorFactoryImpl(
        executor_stack_fn=composing_executor_factory.create_executor)

  def test_cleanup_succeeds_without_init(self):

    def _stack_fn(x):
      del x  # Unused
      return eager_tf_executor.EagerTFExecutor()

    factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
    factory.clean_up_executors()

Example #5
  def test_executor_service_create_one_arg_computation_value_and_call(self):
    ex_factory = executor_factory.ExecutorFactoryImpl(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)

    @computations.tf_computation(tf.int32)
    def comp(x):
      return tf.add(x, 1)

    value_proto, _ = executor_service_utils.serialize_value(comp)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    comp_ref = response.value_ref

    value_proto, _ = executor_service_utils.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    arg_ref = response.value_ref

    response = env.stub.CreateCall(
        executor_pb2.CreateCallRequest(
            function_ref=comp_ref, argument_ref=arg_ref))
    self.assertIsInstance(response, executor_pb2.CreateCallResponse)
    value_id = str(response.value_ref.id)
    value = env.get_value(value_id)
    self.assertEqual(value, 11)
    del env
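
# For comparison, a hedged sketch of invoking a computation directly, without
# the gRPC service layer, using the `set_default_executor` pattern that
# appears in the later examples.
set_default_executor.set_default_executor(
    executor_factory.ExecutorFactoryImpl(
        lambda _: eager_tf_executor.EagerTFExecutor()))

@computations.tf_computation(tf.int32)
def add_one(x):
  return tf.add(x, 1)

assert add_one(10) == 11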

Example #6
def _create_inferred_cardinality_factory(
        max_fanout, stack_func,
        clients_per_thread) -> executor_factory.ExecutorFactory:
    """Creates executor function with variable cardinality."""
    def _create_variable_clients_executors(cardinalities):
        """Constructs executor stacks from `dict` argument."""
        py_typecheck.check_type(cardinalities, dict)
        for k, v in cardinalities.items():
            py_typecheck.check_type(k, placement_literals.PlacementLiteral)
            if k not in [
                    placement_literals.CLIENTS, placement_literals.SERVER
            ]:
                raise ValueError('Unsupported placement: {}.'.format(k))
            if v <= 0:
                raise ValueError(
                    'Cardinality must be at '
                    'least one; you have passed {} for placement {}.'.format(
                        v, k))

        return _create_full_stack(
            cardinalities.get(placement_literals.CLIENTS, 0), max_fanout,
            stack_func, clients_per_thread)

    return executor_factory.ExecutorFactoryImpl(
        executor_stack_fn=_create_variable_clients_executors)

  def test_call_constructs_executor(self):

    def _stack_fn(x):
      del x  # Unused
      return eager_tf_executor.EagerTFExecutor()

    factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
    ex = factory.create_executor({})
    self.assertIsInstance(ex, executor_base.Executor)
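
# Behavior sketch for the validation above. `my_stack_func` is a hypothetical
# stand-in for a real stack-construction function; the commented calls raise
# before it is ever invoked.
factory = _create_inferred_cardinality_factory(
    max_fanout=100, stack_func=my_stack_func, clients_per_thread=1)
executor = factory.create_executor({placement_literals.CLIENTS: 10})
# factory.create_executor({placement_literals.CLIENTS: 0})  -> ValueError
# factory.create_executor({'bad_key': 1})  -> TypeError from py_typecheck.check_type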

Example #8
  def test_executor_service_slowly_create_tensor_value(self):

    class SlowExecutorValue(executor_value_base.ExecutorValue):

      def __init__(self, v, t):
        self._v = v
        self._t = t

      @property
      def type_signature(self):
        return self._t

      async def compute(self):
        return self._v

    class SlowExecutor(executor_base.Executor):
      """An executor whose `create_value` blocks until the test releases it."""

      def __init__(self):
        self.status = 'idle'
        self.busy = threading.Event()  # Set once create_value has started.
        self.done = threading.Event()  # Set by the test to let it finish.

      async def create_value(self, value, type_spec=None):
        self.status = 'busy'
        self.busy.set()
        self.done.wait()  # Block until the test signals completion.
        self.status = 'done'
        return SlowExecutorValue(value, type_spec)

      async def create_call(self, comp, arg=None):
        raise NotImplementedError

      async def create_struct(self, elements):
        raise NotImplementedError

      async def create_selection(self, source, index=None, name=None):
        raise NotImplementedError

      def close(self):
        pass

    ex = SlowExecutor()
    ex_factory = executor_factory.ExecutorFactoryImpl(lambda _: ex)
    env = TestEnv(ex_factory)
    self.assertEqual(ex.status, 'idle')
    value_proto, _ = executor_service_utils.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    ex.busy.wait()
    self.assertEqual(ex.status, 'busy')
    ex.done.set()
    value = env.get_value(response.value_ref.id)
    self.assertEqual(ex.status, 'done')
    self.assertEqual(value, 10)

Example #9
  def test_with_executor_factory(self):
    context_stack = context_stack_impl.context_stack
    executor_factory_impl = executor_factory.ExecutorFactoryImpl(lambda _: None)
    self.assertNotIsInstance(context_stack.current,
                             execution_context.ExecutionContext)

    default_executor.set_default_executor(executor_factory_impl)

    self.assertIsInstance(context_stack.current,
                          execution_context.ExecutionContext)
    self.assertIs(context_stack.current._executor_factory,
                  executor_factory_impl)

  def test_cleanup_calls_close(self):
    ex = eager_tf_executor.EagerTFExecutor()
    ex.close = mock.MagicMock()

    def _stack_fn(x):
      del x  # Unused
      return ex

    factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
    factory.create_executor({})
    factory.clean_up_executors()
    ex.close.assert_called_once()

Example #11
def test_runs_tf(test_obj, executor):
    """Tests `executor` can run a minimal TF computation."""
    py_typecheck.check_type(executor, executor_base.Executor)

    @computations.tf_computation
    def comp():
        return tf.math.add(5, 5)

    executor_factory_impl = executor_factory.ExecutorFactoryImpl(
        lambda _: executor)
    with _execution_context(executor_factory_impl):
        test_obj.assertEqual(comp(), 10)
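
# A minimal sketch of calling the helper above from a test case; the
# `EagerTFExecutor` choice mirrors the other examples.
from absl.testing import absltest

class ExecutorSmokeTest(absltest.TestCase):

  def test_eager_executor_runs_tf(self):
    test_runs_tf(self, eager_tf_executor.EagerTFExecutor())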

Example #12
  def test_executor_service_create_tensor_value(self):
    ex_factory = executor_factory.ExecutorFactoryImpl(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    value_proto, _ = executor_service_utils.serialize_value(
        tf.constant(10.0).numpy(), tf.float32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    value_id = str(response.value_ref.id)
    value = env.get_value(value_id)
    self.assertEqual(value, 10.0)
    del env

Example #13
    def test_server_context_shuts_down_uncaught_exception(
            self, mock_logging_info):

        ex = eager_tf_executor.EagerTFExecutor()
        ex_factory = executor_factory.ExecutorFactoryImpl(lambda _: ex)

        with self.assertRaises(TypeError):
            with server_utils.server_context(
                    ex_factory, 1, portpicker.pick_unused_port()) as server:
                time.sleep(1)
                raise TypeError

        mock_logging_info.assert_called_once_with('Shutting down server.')

Example #14
    def test_server_context_shuts_down_under_keyboard_interrupt(
            self, mock_logging_info):

        ex = eager_tf_executor.EagerTFExecutor()
        ex_factory = executor_factory.ExecutorFactoryImpl(lambda _: ex)

        with server_utils.server_context(
                ex_factory, 1, portpicker.pick_unused_port()) as server:
            time.sleep(1)
            raise KeyboardInterrupt

        mock_logging_info.assert_has_calls([
            mock.call('Server stopped by KeyboardInterrupt.'),
            mock.call('Shutting down server.')
        ])
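
# A hedged sketch of the same pattern outside a test: serve an eager executor
# until interrupted. The port is an illustrative placeholder; per the test
# above, `server_context` absorbs KeyboardInterrupt and logs the shutdown.
ex_factory = executor_factory.ExecutorFactoryImpl(
    lambda _: eager_tf_executor.EagerTFExecutor())
with server_utils.server_context(ex_factory, 1, 8000):
  while True:
    time.sleep(1)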

Example #15
def _create_explicit_cardinality_factory(
        num_clients, max_fanout, stack_func,
        clients_per_thread) -> executor_factory.ExecutorFactory:
    """Creates executor function with fixed cardinality."""
    def _return_executor(cardinalities):
        n_requested_clients = cardinalities.get(placement_literals.CLIENTS)
        if n_requested_clients is not None and n_requested_clients != num_clients:
            raise ValueError(
                'Expected to construct an executor with {} clients, '
                'but executor is hardcoded for {}'.format(
                    n_requested_clients, num_clients))
        return _create_full_stack(num_clients, max_fanout, stack_func,
                                  clients_per_thread)

    return executor_factory.ExecutorFactoryImpl(
        executor_stack_fn=_return_executor)

  def test_construction_with_multiple_cardinalities_reuses_existing_stacks(
      self):
    ex = eager_tf_executor.EagerTFExecutor()
    ex.close = mock.MagicMock()
    num_times_invoked = 0

    def _stack_fn(x):
      del x  # Unused
      nonlocal num_times_invoked
      num_times_invoked += 1
      return ex

    factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
    for _ in range(2):
      factory.create_executor({})
      factory.create_executor({placement_literals.SERVER: 1})
    self.assertEqual(num_times_invoked, 2)

  def test_basic_functionality(self):

    @computations.tf_computation(computation_types.SequenceType(tf.int32))
    def comp(ds):
      return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

    set_default_executor.set_default_executor(
        executor_factory.ExecutorFactoryImpl(
            lambda _: eager_tf_executor.EagerTFExecutor()))

    ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
    v = comp(ds)
    self.assertEqual(v, 25)

    set_default_executor.set_default_executor()
    self.assertIn('ExecutionContext',
                  str(type(context_stack_impl.context_stack.current).__name__))

@contextlib.contextmanager
def test_context(rpc_mode='REQUEST_REPLY'):
    port = portpicker.pick_unused_port()
    server_pool = logging_pool.pool(max_workers=1)
    server = grpc.server(server_pool)
    server.add_insecure_port('[::]:{}'.format(port))
    target_factory = executor_stacks.local_executor_factory(num_clients=3)
    tracers = []

    def _tracer_fn(cardinalities):
        tracer = executor_test_utils.TracingExecutor(
            target_factory.create_executor(cardinalities))
        tracers.append(tracer)
        return tracer

    service = executor_service.ExecutorService(
        executor_factory.ExecutorFactoryImpl(_tracer_fn))
    executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
    server.start()

    channel = grpc.insecure_channel('localhost:{}'.format(port))
    stub = executor_pb2_grpc.ExecutorStub(channel)
    serialized_cards = executor_service_utils.serialize_cardinalities(
        {placement_literals.CLIENTS: 3})
    stub.SetCardinalities(
        executor_pb2.SetCardinalitiesRequest(cardinalities=serialized_cards))

    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        remote_exec)
    try:
        yield collections.namedtuple('_', 'executor tracers')(executor,
                                                              tracers)
    finally:
        executor.close()
        for tracer in tracers:
            tracer.close()
        try:
            channel.close()
        except AttributeError:
            pass  # Public gRPC channel doesn't support close()
        finally:
            server.stop(None)

Example #19
  def test_executor_service_create_and_select_from_tuple(self):
    ex_factory = executor_factory.ExecutorFactoryImpl(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)

    value_proto, _ = executor_service_utils.serialize_value(10, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    ten_ref = response.value_ref
    self.assertEqual(env.get_value(ten_ref.id), 10)

    value_proto, _ = executor_service_utils.serialize_value(20, tf.int32)
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    twenty_ref = response.value_ref
    self.assertEqual(env.get_value(twenty_ref.id), 20)

    response = env.stub.CreateStruct(
        executor_pb2.CreateStructRequest(element=[
            executor_pb2.CreateStructRequest.Element(
                name='a', value_ref=ten_ref),
            executor_pb2.CreateStructRequest.Element(
                name='b', value_ref=twenty_ref)
        ]))
    self.assertIsInstance(response, executor_pb2.CreateStructResponse)
    tuple_ref = response.value_ref
    self.assertEqual(str(env.get_value(tuple_ref.id)), '<a=10,b=20>')

    for arg_name, arg_val, result_val in [('name', 'a', 10), ('name', 'b', 20),
                                          ('index', 0, 10), ('index', 1, 20)]:
      response = env.stub.CreateSelection(
          executor_pb2.CreateSelectionRequest(
              source_ref=tuple_ref, **{arg_name: arg_val}))
      self.assertIsInstance(response, executor_pb2.CreateSelectionResponse)
      selection_ref = response.value_ref
      self.assertEqual(env.get_value(selection_ref.id), result_val)

    del env

Example #20
  def test_executor_service_value_unavailable_after_dispose(self):
    ex_factory = executor_factory.ExecutorFactoryImpl(
        lambda _: eager_tf_executor.EagerTFExecutor())
    env = TestEnv(ex_factory)
    value_proto, _ = executor_service_utils.serialize_value(
        tf.constant(10.0).numpy(), tf.float32)
    # Create the value
    response = env.stub.CreateValue(
        executor_pb2.CreateValueRequest(value=value_proto))
    self.assertIsInstance(response, executor_pb2.CreateValueResponse)
    value_id = str(response.value_ref.id)
    # Check that the value appears in the _values map
    env.get_value_future_directly(value_id)
    # Dispose of the value
    dispose_request = executor_pb2.DisposeRequest()
    dispose_request.value_ref.append(response.value_ref)
    response = env.stub.Dispose(dispose_request)
    self.assertIsInstance(response, executor_pb2.DisposeResponse)
    # Check that the value is gone from the _values map
    # get_value_future_directly is used here so that we can catch the
    # exception rather than having it occur on the GRPC thread.
    with self.assertRaises(KeyError):
      env.get_value_future_directly(value_id)

Example #21
def create_test_executor_factory():
    executor = eager_tf_executor.EagerTFExecutor()
    return executor_factory.ExecutorFactoryImpl(lambda _: executor)

Example #22
def local_executor_factory(
    num_clients=None,
    max_fanout=100,
    num_client_executors=32,
    server_tf_device=None,
    client_tf_devices=tuple()
) -> executor_factory.ExecutorFactory:
    """Constructs an executor factory to execute computations locally.

    Note: The `tff.federated_secure_sum()` intrinsic is not implemented by this
    executor.

    Args:
      num_clients: The number of clients. If specified, the executor factory
        function returned by `local_executor_factory` will be configured to
        have exactly `num_clients` clients. If unspecified (`None`), then the
        function returned will attempt to infer cardinalities of all placements
        for which it is passed values.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.
      num_client_executors: The number of distinct client executors to run
        concurrently; executing more clients than this number results in
        multiple clients having their work pinned on a single executor in a
        synchronous fashion.
      server_tf_device: A `tf.config.LogicalDevice` to place server and other
        computations without explicit TFF placement.
      client_tf_devices: List/tuple of `tf.config.LogicalDevice` to place
        clients for simulation. Possibly accelerators returned by
        `tf.config.list_logical_devices()`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `num_clients` is specified and is less than one.
    """
    if server_tf_device is not None:
        py_typecheck.check_type(server_tf_device, tf.config.LogicalDevice)
    py_typecheck.check_type(client_tf_devices, (tuple, list))
    py_typecheck.check_type(max_fanout, int)
    py_typecheck.check_type(num_client_executors, int)
    if num_clients is not None:
        py_typecheck.check_type(num_clients, int)
    if max_fanout < 2:
        raise ValueError('Max fanout must be greater than 1.')
    unplaced_ex_factory = UnplacedExecutorFactory(
        use_caching=True,
        server_device=server_tf_device,
        client_devices=client_tf_devices)
    federating_executor_factory = FederatingExecutorFactory(
        num_client_executors=num_client_executors,
        unplaced_ex_factory=unplaced_ex_factory,
        num_clients=num_clients,
        use_sizing=False)

    def _factory_fn(
        cardinalities: executor_factory.CardinalitiesType
    ) -> executor_base.Executor:
        return _create_full_stack(
            cardinalities,
            max_fanout,
            stack_func=federating_executor_factory.create_executor,
            unplaced_ex_factory=unplaced_ex_factory)

    return executor_factory.ExecutorFactoryImpl(_factory_fn)
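
# A minimal usage sketch, following the `set_default_executor` pattern used
# elsewhere in these examples: install a fixed-cardinality local stack, after
# which TFF computations execute against it.
factory = local_executor_factory(num_clients=3, max_fanout=100)
set_default_executor.set_default_executor(factory)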

Example #23
def thread_debugging_executor_factory(
    num_clients=None,
    clients_per_thread=1,
) -> executor_factory.ExecutorFactory:
    r"""Constructs a simplified execution stack to execute local computations.

    The constructed executors support a limited set of TFF's computations. In
    particular, the debug executor can only resolve references at the top level
    of the stack, and therefore assumes that all local computation is expressed
    in pure TensorFlow.

    The debug executor makes particular guarantees about the structure of the
    stacks it constructs which are intended to make them maximally easy to
    reason about. That is, the debug executor will essentially execute exactly
    the computation it is passed (in particular, this implies that there are no
    caching layers), and uses its declared inability to execute arbitrary TFF
    lambdas to reduce the complexity of the constructed stack. Every debugging
    executor will have a similar structure, the simplest structure that can
    execute the full expressivity of TFF while running each client in a
    dedicated thread:


                          ReferenceResolvingExecutor
                                    |
                              FederatingExecutor
                            /       ...         \
      ThreadDelegatingExecutor                  ThreadDelegatingExecutor
                  |                                         |
          EagerTFExecutor                            EagerTFExecutor


    This structure can be useful in understanding the concurrency pattern of
    TFF execution, and where the TFF runtime infers data dependencies.

    Args:
      num_clients: The number of clients. If specified, the executor factory
        function returned by `thread_debugging_executor_factory` will be
        configured to have exactly `num_clients` clients. If unspecified
        (`None`), then the function returned will attempt to infer
        cardinalities of all placements for which it is passed values.
      clients_per_thread: Integer number of clients for each of TFF's threads
        to run in sequence. Increasing `clients_per_thread` therefore reduces
        the concurrency of the TFF runtime, which can be useful if client work
        is very lightweight or models are very large and multiple copies cannot
        fit in memory.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `num_clients` is specified and is less than one.
    """
    py_typecheck.check_type(clients_per_thread, int)
    if num_clients is not None:
        py_typecheck.check_type(num_clients, int)
    unplaced_ex_factory = UnplacedExecutorFactory(
        use_caching=False,
        can_resolve_references=False,
    )
    federating_executor_factory = FederatingExecutorFactory(
        clients_per_thread=clients_per_thread,
        unplaced_ex_factory=unplaced_ex_factory,
        num_clients=num_clients,
        use_sizing=False)

    return executor_factory.ExecutorFactoryImpl(
        federating_executor_factory.create_executor)
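
# Usage sketch: the debugging factory is driven the same way as the others;
# the cardinalities dict follows the pattern used in the tests above.
factory = thread_debugging_executor_factory(clients_per_thread=2)
executor = factory.create_executor({placement_literals.CLIENTS: 4})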

Example #24
def local_executor_factory(
    num_clients=None,
    max_fanout=100,
    clients_per_thread=1,
    server_tf_device=None,
    client_tf_devices=tuple()
) -> executor_factory.ExecutorFactory:
    """Constructs an executor factory to execute computations locally.

    Note: The `tff.federated_secure_sum()` intrinsic is not implemented by this
    executor.

    Args:
      num_clients: The number of clients. If specified, the executor factory
        function returned by `local_executor_factory` will be configured to
        have exactly `num_clients` clients. If unspecified (`None`), then the
        function returned will attempt to infer cardinalities of all placements
        for which it is passed values.
      max_fanout: The maximum fanout at any point in the aggregation hierarchy.
        If `num_clients > max_fanout`, the constructed executor stack will
        consist of multiple levels of aggregators. The height of the stack will
        be on the order of `log(num_clients) / log(max_fanout)`.
      clients_per_thread: Integer number of clients for each of TFF's threads
        to run in sequence. Increasing `clients_per_thread` therefore reduces
        the concurrency of the TFF runtime, which can be useful if client work
        is very lightweight or models are very large and multiple copies cannot
        fit in memory.
      server_tf_device: A `tf.config.LogicalDevice` to place server and other
        computations without explicit TFF placement.
      client_tf_devices: List/tuple of `tf.config.LogicalDevice` to place
        clients for simulation. Possibly accelerators returned by
        `tf.config.list_logical_devices()`.

    Returns:
      An instance of `executor_factory.ExecutorFactory` encapsulating the
      executor construction logic specified above.

    Raises:
      ValueError: If `num_clients` is specified and is less than one.
    """
    if server_tf_device is not None:
        py_typecheck.check_type(server_tf_device, tf.config.LogicalDevice)
    py_typecheck.check_type(client_tf_devices, (tuple, list))
    py_typecheck.check_type(max_fanout, int)
    py_typecheck.check_type(clients_per_thread, int)
    if num_clients is not None:
        py_typecheck.check_type(num_clients, int)
    if max_fanout < 2:
        raise ValueError('Max fanout must be greater than 1.')
    unplaced_ex_factory = UnplacedExecutorFactory(
        use_caching=True,
        server_device=server_tf_device,
        client_devices=client_tf_devices)
    federating_executor_factory = FederatingExecutorFactory(
        clients_per_thread=clients_per_thread,
        unplaced_ex_factory=unplaced_ex_factory,
        num_clients=num_clients,
        use_sizing=False)
    flat_stack_fn = create_minimal_length_flat_stack_fn(
        max_fanout, federating_executor_factory)
    full_stack_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=unplaced_ex_factory,
        flat_stack_fn=flat_stack_fn,
    )
    return executor_factory.ExecutorFactoryImpl(
        full_stack_factory.create_executor)

def test_runs_tf(test_obj, executor):
    """Tests `executor` can run a minimal TF computation."""
    py_typecheck.check_type(executor, executor_base.Executor)
    set_default_executor.set_default_executor(
        executor_factory.ExecutorFactoryImpl(lambda _: executor))
    test_obj.assertEqual(_dummy_tf_computation(), 10)

def create_test_executor_factory():
    executor = create_test_executor(num_clients=1)
    return executor_factory.ExecutorFactoryImpl(lambda _: executor)

Example #27
def create_test_executor_factory():
    executor = eager_tf_executor.EagerTFExecutor()
    executor = caching_executor.CachingExecutor(executor)
    executor = reference_resolving_executor.ReferenceResolvingExecutor(
        executor)
    return executor_factory.ExecutorFactoryImpl(lambda _: executor)

Example #28
def remote_executor_factory(
        channels,
        rpc_mode='REQUEST_REPLY',
        thread_pool_executor=None,
        dispose_batch_size=20,
        max_fanout=100) -> executor_factory.ExecutorFactory:
    """Create an executor backed by remote workers.

  Args:
    channels: A list of `grpc.Channels` hosting services which can execute TFF
      work.
    rpc_mode: A string specifying the connection mode between the local host and
      `channels`.
    thread_pool_executor: Optional concurrent.futures.Executor used to wait for
      the reply to a streaming RPC message. Uses the default Executor if not
      specified.
    dispose_batch_size: The batch size for requests to dispose of remote worker
      values. Lower values will result in more requests to the remote worker,
      but will result in values being cleaned up sooner and therefore may result
      in lower memory usage on the remote worker.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy. If
      `num_clients > max_fanout`, the constructed executor stack will consist of
      multiple levels of aggregators. The height of the stack will be on the
      order of `log(num_clients) / log(max_fanout)`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.
  """
    py_typecheck.check_type(channels, list)
    if not channels:
        raise ValueError('The list of channels cannot be empty.')
    remote_executors = [
        remote_executor.RemoteExecutor(channel, rpc_mode, thread_pool_executor,
                                       dispose_batch_size)
        for channel in channels
    ]

    def _configure_remote_executor(ex, cardinalities, loop):
        """Configures `ex` to run the appropriate number of clients."""
        loop.run_until_complete(ex.set_cardinalities(cardinalities))
        return

    def _configure_remote_workers(cardinalities):
        loop = asyncio.new_event_loop()
        try:
            if not cardinalities.get(placement_literals.CLIENTS):
                for ex in remote_executors:
                    _configure_remote_executor(ex, cardinalities, loop)
                return [
                    _wrap_executor_in_threading_stack(e)
                    for e in remote_executors
                ]

            remaining_clients = cardinalities[placement_literals.CLIENTS]
            clients_per_most_executors = remaining_clients // len(
                remote_executors)

            for ex in remote_executors[:-1]:
                _configure_remote_executor(
                    ex,
                    {placement_literals.CLIENTS: clients_per_most_executors},
                    loop)
                remaining_clients -= clients_per_most_executors
            _configure_remote_executor(
                remote_executors[-1],
                {placement_literals.CLIENTS: remaining_clients}, loop)
        finally:
            loop.close()
        return [_wrap_executor_in_threading_stack(e) for e in remote_executors]

    flat_stack_fn = _configure_remote_workers
    unplaced_ex_factory = UnplacedExecutorFactory(use_caching=True)
    composing_executor_factory = ComposingExecutorFactory(
        max_fanout=max_fanout,
        unplaced_ex_factory=unplaced_ex_factory,
        flat_stack_fn=flat_stack_fn,
    )

    return executor_factory.ExecutorFactoryImpl(
        executor_stack_fn=composing_executor_factory.create_executor)
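
# A hedged usage sketch: connect to two already-running worker services and
# let the factory split clients across them (addresses are placeholders).
channels = [grpc.insecure_channel('localhost:8000'),
            grpc.insecure_channel('localhost:8001')]
factory = remote_executor_factory(channels, rpc_mode='REQUEST_REPLY')
executor = factory.create_executor({placement_literals.CLIENTS: 10})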