def test_context(rpc_mode='REQUEST_REPLY'):
  """Serves a local executor stack over gRPC and installs a remote client.

  Starts an in-process gRPC server hosting an `ExecutorService` backed by a
  `TracingExecutor` around a local executor stack (3 clients). It then
  connects a `RemoteExecutor` to that server and installs it, wrapped in a
  `LambdaExecutor`, as the default executor. This function is a generator;
  presumably it is wrapped with `contextlib.contextmanager` at its definition
  site — TODO confirm.

  Args:
    rpc_mode: Passed through unchanged to `remote_executor.RemoteExecutor`.

  Yields:
    A namedtuple with fields `executor` (the installed `LambdaExecutor`) and
    `tracer` (the `TracingExecutor` wrapping the server-side stack).
  """
  port = portpicker.pick_unused_port()
  server_pool = logging_pool.pool(max_workers=1)
  server = grpc.server(server_pool)
  server.add_insecure_port('[::]:{}'.format(port))
  target_executor = executor_stacks.local_executor_factory(
      num_clients=3).create_executor({})
  tracer = executor_test_utils.TracingExecutor(target_executor)
  service = executor_service.ExecutorService(tracer)
  executor_pb2_grpc.add_ExecutorServicer_to_server(service, server)
  server.start()
  # All setup after `server.start()` lives inside the `try` so the server is
  # stopped even if client-side construction fails before reaching `yield`.
  channel = None
  try:
    channel = grpc.insecure_channel('localhost:{}'.format(port))
    remote_exec = remote_executor.RemoteExecutor(channel, rpc_mode)
    executor = lambda_executor.LambdaExecutor(remote_exec)
    set_default_executor.set_default_executor(
        executor_factory.ExecutorFactoryImpl(lambda _: executor))
    yield collections.namedtuple('_', 'executor tracer')(executor, tracer)
  finally:
    # Restore the default executor installed before this context.
    set_default_executor.set_default_executor()
    try:
      # `channel` stays None if creating it raised above.
      if channel is not None:
        channel.close()
    except AttributeError:
      pass  # Public gRPC channel doesn't support close()
    finally:
      server.stop(None)
Beispiel #2
0
def worker_pool_executor_factory(executors,
                                 max_fanout=100
                                ) -> executor_factory.ExecutorFactory:
  """Builds an executor factory whose stacks forward work to a worker pool.

  Args:
    executors: A list of `tff.framework.Executor` instances that forward work to
      workers in the worker pool. These can be any type of executors, but in
      most scenarios, they will be instances of `tff.framework.RemoteExecutor`.
    max_fanout: The maximum fanout at any point in the aggregation hierarchy.
      When the number of clients exceeds `max_fanout`, the constructed executor
      stack consists of multiple levels of aggregators. The height of the stack
      is on the order of `log(num_clients) / log(max_fanout)`.

  Returns:
    An instance of `executor_factory.ExecutorFactory` encapsulating the
    executor construction logic specified above.

  Raises:
    ValueError: If `executors` is empty or `max_fanout` is smaller than 2.
  """
  py_typecheck.check_type(executors, list)
  py_typecheck.check_type(max_fanout, int)
  if not executors:
    raise ValueError('The list executors cannot be empty.')
  if max_fanout < 2:
    raise ValueError('Max fanout must be greater than 1.')
  # Complete each worker-facing executor once, up front; every stack built by
  # the factory aggregates over these same completed stacks.
  completed_stacks = list(map(_complete_stack, executors))

  def _stack_fn(cardinalities):
    # Requested cardinalities are ignored; the stack always spans all workers.
    del cardinalities
    return _aggregate_stacks(completed_stacks, max_fanout)

  return executor_factory.ExecutorFactoryImpl(executor_stack_fn=_stack_fn)
    def test_cleanup_succeeds_without_init(self):
        """Cleaning up before any executor was created must not raise."""

        def _build_eager(cardinalities):
            del cardinalities  # Unused
            return eager_executor.EagerExecutor()

        executor_factory.ExecutorFactoryImpl(_build_eager).clean_up_executors()
    def test_concrete_class_instantiates_stack_fn(self):
        """ExecutorFactoryImpl can be constructed from a plain stack function."""

        def _build_eager(cardinalities):
            del cardinalities  # Unused
            return eager_executor.EagerExecutor()

        constructed = executor_factory.ExecutorFactoryImpl(_build_eager)
        self.assertIsInstance(constructed,
                              executor_factory.ExecutorFactoryImpl)
Beispiel #5
0
def _create_inferred_cardinality_factory(
        max_fanout, stack_func,
        clients_per_thread) -> executor_factory.ExecutorFactory:
    """Returns a factory whose client count is read from each request.

    The returned factory's stack function validates the `cardinalities` dict
    it receives. Only `CLIENTS` and `SERVER` placements are accepted, each with
    a positive count. It then builds a full stack sized by the requested
    `CLIENTS` count (0 when absent).
    """

    def _create_variable_clients_executors(cardinalities):
        """Validates `cardinalities` and builds a matching executor stack."""
        py_typecheck.check_type(cardinalities, dict)
        supported_placements = (placement_literals.CLIENTS,
                                placement_literals.SERVER)
        for placement, count in cardinalities.items():
            py_typecheck.check_type(placement,
                                    placement_literals.PlacementLiteral)
            if placement not in supported_placements:
                raise ValueError(
                    'Unsupported placement: {}.'.format(placement))
            if count <= 0:
                raise ValueError(
                    'Cardinality must be at '
                    'least one; you have passed {} for placement {}.'.format(
                        count, placement))

        requested_clients = cardinalities.get(placement_literals.CLIENTS, 0)
        return _create_full_stack(requested_clients, max_fanout, stack_func,
                                  clients_per_thread)

    return executor_factory.ExecutorFactoryImpl(
        executor_stack_fn=_create_variable_clients_executors)
    def test_call_constructs_executor(self):
        """create_executor returns an Executor built by the stack function."""
        factory = executor_factory.ExecutorFactoryImpl(
            lambda unused_cardinalities: eager_executor.EagerExecutor())
        created = factory.create_executor({})
        self.assertIsInstance(created, executor_base.Executor)
    def test_cleanup_calls_close(self):
        """clean_up_executors closes the executor the factory handed out."""
        mocked_executor = eager_executor.EagerExecutor()
        mocked_executor.close = mock.MagicMock()
        factory = executor_factory.ExecutorFactoryImpl(
            lambda unused_cardinalities: mocked_executor)

        factory.create_executor({})
        factory.clean_up_executors()

        mocked_executor.close.assert_called_once()
Beispiel #8
0
def _create_explicit_cardinality_factory(
    num_clients, max_fanout, stack_func,
    clients_per_thread) -> executor_factory.ExecutorFactory:
  """Returns a factory that always builds a stack for `num_clients` clients.

  The factory's stack function accepts a request that omits `CLIENTS`, or
  one that asks for exactly `num_clients`. Any other client count raises
  `ValueError`.
  """

  def _return_executor(cardinalities):
    requested = cardinalities.get(placement_literals.CLIENTS)
    # Happy path: no explicit client count, or one matching the fixed size.
    if requested is None or requested == num_clients:
      return _create_full_stack(num_clients, max_fanout, stack_func,
                                clients_per_thread)
    raise ValueError('Expected to construct an executor with {} clients, '
                     'but executor is hardcoded for {}'.format(
                         requested, num_clients))

  return executor_factory.ExecutorFactoryImpl(
      executor_stack_fn=_return_executor)
    def setUp(self):
        super().setUp()
        # Per the original fixture comment: 2 clients per worker stack,
        # 3 worker stacks per middle stack, 2 middle stacks => 12 clients.
        self._num_clients = 12

        def _middle_over_three_workers():
            return _create_middle_stack(
                [_create_worker_stack() for _ in range(3)])

        def _stack_fn(cardinalities):
            del cardinalities  # Unused
            # List elements evaluate left to right, same build order as before.
            return _create_middle_stack(
                [_middle_over_three_workers(), _middle_over_three_workers()])

        set_default_executor.set_default_executor(
            executor_factory.ExecutorFactoryImpl(_stack_fn))
    def test_construction_with_multiple_cardinalities_reuses_existing_stacks(
            self):
        """Two rounds over two cardinalities invoke the stack fn only twice."""
        ex = eager_executor.EagerExecutor()
        ex.close = mock.MagicMock()
        invocations = []

        def _stack_fn(cardinalities):
            invocations.append(cardinalities)
            return ex

        factory = executor_factory.ExecutorFactoryImpl(_stack_fn)
        for _ in range(2):
            # Fresh dicts each round, exactly as requested before.
            factory.create_executor({})
            factory.create_executor({placement_literals.SERVER: 1})
        self.assertEqual(len(invocations), 2)
    def test_basic_functionality(self):
        """A sequence-reducing computation runs on a custom default executor.

        Also checks that resetting the default executor restores an
        `ExecutionContext` as the current context.
        """
        @computations.tf_computation(computation_types.SequenceType(tf.int32))
        def comp(ds):
            return ds.take(5).reduce(np.int32(0), lambda x, y: x + y)

        set_default_executor.set_default_executor(
            executor_factory.ExecutorFactoryImpl(
                lambda _: eager_executor.EagerExecutor()))
        try:
            ds = tf.data.Dataset.range(1).map(lambda x: tf.constant(5)).repeat()
            v = comp(ds)
            self.assertEqual(v, 25)
        finally:
            # Always reset, so a failure above does not leak the custom
            # default executor into tests that run afterwards.
            set_default_executor.set_default_executor()

        self.assertIn(
            'ExecutionContext',
            type(context_stack_impl.context_stack.current).__name__)
  def test_end_to_end(self):
    """A ConcurrentExecutor can serve two successive computation calls."""

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return tf.add(x, 1)

    ex = concurrent_executor.ConcurrentExecutor(eager_executor.EagerExecutor())

    set_default_executor.set_default_executor(
        executor_factory.ExecutorFactoryImpl(lambda _: ex))
    try:
      self.assertEqual(add_one(7), 8)

      # After this invocation, the ConcurrentExecutor has been closed, and needs
      # to be re-initialized.

      self.assertEqual(add_one(8), 9)
    finally:
      # Reset even on failure so the custom executor does not leak into
      # subsequent tests.
      set_default_executor.set_default_executor()
Beispiel #13
0
def test_runs_tf(test_obj, executor):
    """Checks that `executor` can run a minimal TF computation.

    Installs `executor` as the default executor, through a factory that
    ignores requested cardinalities. Then asserts, via `test_obj`, that the
    dummy TF computation evaluates to 10.
    """
    py_typecheck.check_type(executor, executor_base.Executor)
    factory = executor_factory.ExecutorFactoryImpl(lambda _: executor)
    set_default_executor.set_default_executor(factory)
    test_obj.assertEqual(_dummy_tf_computation(), 10)