예제 #1
0
    def _executor_stack_fn(
        cardinalities: executor_factory.CardinalitiesType
    ) -> Tuple[executor_base.Executor, List[sizing_executor.SizingExecutor]]:
        """Builds the executor stack and its sizing executors for `cardinalities`.

        This is the stack function handed to `SizingExecutorFactoryImpl`. Unlike
        the function handed to `ExecutorFactoryImpl`, it also returns every
        sizing executor created while the stack was assembled.

        Args:
          cardinalities: Mapping from placements to counts; its optional
            `CLIENTS` entry is the number of clients requested.

        Returns:
          A 2-tuple of the top-level executor built by `_create_full_stack` and
          the list of sizing executors gathered from every sub-stack produced by
          `_create_sizing_stack` during that build.

        Raises:
          ValueError: If an explicit `num_clients` is configured and it is not
            positive, or it disagrees with a requested `CLIENTS` count.
        """
        collected_sizing_executors = []
        shared_unplaced_factory = UnplacedExecutorFactory(use_caching=True)

        def _standalone_stack_func(num_clients, num_client_executors):
            # Build one sub-stack and remember the sizing executors it created.
            # The list is only mutated, never rebound, so no `nonlocal` needed.
            sub_stack, new_sizing_executors = _create_sizing_stack(
                num_clients,
                num_client_executors=num_client_executors,
                unplaced_ex_factory=shared_unplaced_factory)
            collected_sizing_executors.extend(new_sizing_executors)
            return sub_stack

        if num_clients is None:
            # Inferred case: size the stack from the request, defaulting to 0.
            clients_to_build = cardinalities.get(placement_literals.CLIENTS, 0)
        else:
            # Explicit case: the stack is pinned to `num_clients`, so a request
            # must either omit CLIENTS or agree with it.
            py_typecheck.check_type(num_clients, int)
            if num_clients <= 0:
                raise ValueError(
                    'If specifying `num_clients`, cardinality must be at '
                    'least one; you have passed {}.'.format(num_clients))
            requested = cardinalities.get(placement_literals.CLIENTS)
            if requested is not None and requested != num_clients:
                raise ValueError(
                    'Expected to construct an executor with {} clients, '
                    'but executor is hardcoded for {}'.format(
                        requested, num_clients))
            clients_to_build = num_clients

        # Building the stack runs `_standalone_stack_func`, which fills
        # `collected_sizing_executors` before it is returned below.
        top_level_executor = _create_full_stack(
            clients_to_build,
            max_fanout,
            _standalone_stack_func,
            num_client_executors,
            unplaced_ex_factory=shared_unplaced_factory)
        return top_level_executor, collected_sizing_executors
예제 #2
0
 def _executor_fn(
     cardinalities: executor_factory.CardinalitiesType
 ) -> executor_bindings.Executor:
     """Creates a remote executor stack over `channels` for `cardinalities`.

     If `cardinalities` has no `CLIENTS` entry (or it maps to `None`), the
     stack is requested for `default_num_clients` clients instead.

     Args:
       cardinalities: Mapping from placements to the number of participants
         requested at each placement. The caller's mapping is not modified.

     Returns:
       The executor produced by
       `executor_stack_bindings.create_remote_executor_stack`.
     """
     if cardinalities.get(placements.CLIENTS) is None:
         # Fill in the default on a copy. Writing into the caller's mapping
         # would make the default visible to (and persist in) caller state,
         # e.g. a mapping the caller reuses or keys a cache on.
         cardinalities = {
             **cardinalities, placements.CLIENTS: default_num_clients
         }
     return executor_stack_bindings.create_remote_executor_stack(
         channels, cardinalities)
예제 #3
0
 def _factory_fn(
     cardinalities: executor_factory.CardinalitiesType
 ) -> Tuple[executor_base.Executor, List[sizing_executor.SizingExecutor]]:
     """Creates an executor for `cardinalities` along with the sizing executors.

     Requests with fewer than `max_fanout` clients (a missing `CLIENTS` entry
     counts as 0) are created directly by `federating_executor_factory`;
     larger requests are created by `full_stack_factory`.

     Args:
       cardinalities: Mapping from placements to the number of participants
         requested at each placement.

     Returns:
       A 2-tuple of the created executor and
       `federating_executor_factory.sizing_executors`. The previous annotation
       (a bare `executor_base.Executor`) did not match this tuple.
     """
     num_clients = cardinalities.get(placement_literals.CLIENTS, 0)
     if num_clients < max_fanout:
         executor = federating_executor_factory.create_executor(cardinalities)
     else:
         executor = full_stack_factory.create_executor(cardinalities)
     # NOTE(review): sizing executors are read from `federating_executor_factory`
     # in both branches; this presumably relies on `full_stack_factory` building
     # its sub-stacks through that same factory -- confirm.
     sizing_executor_list = federating_executor_factory.sizing_executors
     return executor, sizing_executor_list
예제 #4
0
 def _validate_requested_clients(
         self, cardinalities: executor_factory.CardinalitiesType) -> int:
     """Resolves how many clients to construct for `cardinalities`.

     Args:
       cardinalities: Mapping whose optional `CLIENTS` entry is the requested
         number of clients.

     Returns:
       The requested `CLIENTS` count when present; otherwise the configured
       `self._num_clients`, or 0 when that is also unset.

     Raises:
       ValueError: If a client count is both configured and requested and the
         two values differ.
     """
     requested = cardinalities.get(placement_literals.CLIENTS)
     configured = self._num_clients
     if requested is None:
         return 0 if configured is None else configured
     if configured is not None and configured != requested:
         raise ValueError(
             'FederatingStackFactory configured to return {} '
             'clients, but encountered a request for {} clients.'.format(
                 configured, requested))
     return requested
예제 #5
0
 def _validate_requested_clients(
         self, cardinalities: executor_factory.CardinalitiesType) -> int:
     """Resolves how many clients to construct for `cardinalities`.

     Args:
       cardinalities: Mapping whose optional `CLIENTS` entry is the requested
         number of clients.

     Returns:
       The requested `CLIENTS` count when present; otherwise the configured
       `self._num_clients`, or 0 when that is also unset.

     Raises:
       ValueError: If a client count is both configured and requested and the
         two values differ.
     """
     num_requested_clients = cardinalities.get(placements.CLIENTS)
     if num_requested_clients is None:
         # Nothing requested: fall back to the configured count, or 0.
         return self._num_clients if self._num_clients is not None else 0
     if (self._num_clients is not None
             and self._num_clients != num_requested_clients):
         # The trailing space after "clients." keeps the two sentences apart;
         # previously the message rendered as "...clients.If your ...".
         raise ValueError(
             'FederatingStackFactory configured to return {} '
             'clients, but encountered a request for {} clients. '
             'If your computation accepts CLIENTS-placed arguments, it is '
             'recommended to avoid setting the num_clients parameter in the TFF '
             'runtime.'.format(self._num_clients, num_requested_clients))
     return num_requested_clients
예제 #6
0
 def create_executor_list(
     cardinalities: executor_factory.CardinalitiesType
 ) -> List[executor_base.Executor]:
     """Splits the requested clients across flat federated stacks.

     Each stack is created by `federated_stack_factory` from a copy of
     `cardinalities` whose `CLIENTS` entry is replaced by that stack's share,
     with at most `max_clients_per_stack` clients per stack.

     Args:
       cardinalities: Mapping from placements to counts; a missing `CLIENTS`
         entry is treated as 0.

     Returns:
       A list of executors. When no clients are requested it holds exactly one
       executor, built from the unmodified `cardinalities`.

     Raises:
       ValueError: If the requested number of clients is negative.
     """
     remaining = cardinalities.get(placement_literals.CLIENTS, 0)
     if remaining < 0:
         raise ValueError('Number of clients cannot be negative.')
     if remaining < 1:
         only_stack = federated_stack_factory.create_executor(
             cardinalities=cardinalities)
         return [only_stack]
     # NOTE(review): assumes `max_clients_per_stack` >= 1; a non-positive value
     # never reduces `remaining`, so this loop would not terminate -- confirm the
     # enclosing function validates it.
     stacks = []
     while remaining > 0:
         chunk = min(remaining, max_clients_per_stack)
         stacks.append(
             federated_stack_factory.create_executor(
                 {**cardinalities, placement_literals.CLIENTS: chunk}))
         remaining -= chunk
     return stacks
예제 #7
0
    def create_executor(
        self, cardinalities: executor_factory.CardinalitiesType
    ) -> executor_base.Executor:
        """Creates an executor hierarchy of maximum width `self._max_fanout`.

        If this factory was configured with prebuilt child executors
        (`self._child_executors` is not None), those children are passed to
        `self._aggregate_stacks` and `cardinalities` is not consulted.

        Otherwise the requested clients are split into the minimum number of
        federated stacks, each built by `self._federated_stack_factory` with
        no more than `self._max_fanout` clients, and the stacks are combined by
        `self._aggregate_stacks` into a hierarchy of width no greater than
        `self._max_fanout`. When no clients are requested, the single stack
        built from the unmodified `cardinalities` is returned directly.

        Args:
          cardinalities: A mapping from placements to integers specifying the
            cardinalities at each placement.

        Returns:
          An `executor_base.Executor` satisfying the conditions above.

        Raises:
          ValueError: If the requested number of clients is negative.
        """
        if self._child_executors is not None:
            return self._aggregate_stacks(self._child_executors)

        remaining = cardinalities.get(placement_literals.CLIENTS, 0)
        if remaining < 0:
            raise ValueError('Number of clients cannot be negative.')
        if remaining < 1:
            return self._federated_stack_factory.create_executor(
                cardinalities=cardinalities)

        stacks = []
        while remaining > 0:
            stack_size = min(remaining, self._max_fanout)
            # Copy every placement's cardinality; only CLIENTS is overridden.
            stack_cardinalities = {
                **cardinalities, placement_literals.CLIENTS: stack_size
            }
            stacks.append(
                self._federated_stack_factory.create_executor(
                    stack_cardinalities))
            remaining -= stack_size
        return self._aggregate_stacks(stacks)
예제 #8
0
 def _executor_fn(
     cardinalities: executor_factory.CardinalitiesType
 ) -> executor_bindings.Executor:
     """Builds a reference-resolving / federating / TensorFlow executor stack.

     The returned executor is a reference-resolving executor wrapping a
     federating executor, which in turn wraps a reference-resolving executor
     over a TensorFlow executor created with `max_concurrent_computation_calls`.

     If `cardinalities` has no `CLIENTS` entry (or it maps to `None`),
     `default_num_clients` is used. When `max_concurrent_computation_calls` is
     positive and smaller than the number of clients, a warning is issued via
     `_log_and_warn_on_sequential_execution`.

     Args:
       cardinalities: Mapping from placements to the number of participants
         requested at each placement. The caller's mapping is not modified.

     Returns:
       The top-level reference-resolving executor.
     """
     if cardinalities.get(placements.CLIENTS) is None:
         # Fill in the default on a copy rather than writing into the caller's
         # mapping, which would leak the default back into caller state.
         cardinalities = {
             **cardinalities, placements.CLIENTS: default_num_clients
         }
     num_clients = cardinalities[placements.CLIENTS]
     if (max_concurrent_computation_calls > 0
             and num_clients > max_concurrent_computation_calls):
         # More clients than allowed concurrent calls: report roughly how many
         # rounds of sequential execution to expect.
         expected_concurrency_factor = math.ceil(
             num_clients / max_concurrent_computation_calls)
         _log_and_warn_on_sequential_execution(
             max_concurrent_computation_calls, num_clients,
             expected_concurrency_factor)
     tf_executor = executor_bindings.create_tensorflow_executor(
         max_concurrent_computation_calls)
     sub_federating_reference_resolving_executor = (
         executor_bindings.create_reference_resolving_executor(tf_executor))
     federating_ex = executor_bindings.create_federating_executor(
         sub_federating_reference_resolving_executor, cardinalities)
     top_level_reference_resolving_ex = (
         executor_bindings.create_reference_resolving_executor(federating_ex))
     return top_level_reference_resolving_ex
예제 #9
0
def _create_full_stack(
    cardinalities: executor_factory.CardinalitiesType,
    max_fanout: int,
    stack_func: Callable[[executor_factory.CardinalitiesType],
                         executor_base.Executor],
    unplaced_ex_factory: UnplacedExecutorFactory,
) -> executor_base.Executor:
    """Creates a full executor stack.

    The requested clients are partitioned into as few sub-stacks as possible,
    each with at most `max_fanout` clients. Every sub-stack is built by
    `stack_func` from a copy of `cardinalities` whose `CLIENTS` entry is
    replaced by that sub-stack's share; all other placements keep their
    cardinalities. The sub-stacks are then combined by `_aggregate_stacks`.
    If no clients are requested, `stack_func` is called once with the
    unmodified `cardinalities` and its result is returned directly.

    Args:
      cardinalities: The cardinalities to create at each placement. A missing
        `CLIENTS` entry is treated as 0.
      max_fanout: The maximum fanout at any point in the hierarchy. Must be 2 or
        larger.
      stack_func: A function taking a dict of cardinalities and returning an
        `executor_base.Executor`.
      unplaced_ex_factory: The unplaced executor factory to use in constructing
        executors to execute unplaced computations in the hierarchy.

    Returns:
      An executor stack, potentially multi-level, that spans all clients.

    Raises:
      ValueError: If the number of clients is negative, or if clients are
        requested and `max_fanout` is not positive. Other fanout violations
        (the "2 or larger" requirement) are presumably reported by
        `_aggregate_stacks` -- confirm.
    """
    num_clients = cardinalities.get(placement_literals.CLIENTS, 0)
    py_typecheck.check_type(max_fanout, int)
    if num_clients < 0:
        raise ValueError('Number of clients cannot be negative.')
    if num_clients < 1:
        return stack_func(cardinalities=cardinalities)  # pytype: disable=wrong-keyword-args
    if max_fanout < 1:
        # A non-positive fanout never shrinks `num_clients` in the loop below,
        # which would then spin forever while appending executors.
        raise ValueError(
            'Max fanout must be positive to split {} clients; got {}.'.format(
                num_clients, max_fanout))
    executors = []
    while num_clients > 0:
        n = min(num_clients, max_fanout)
        # Preserve every other placement's cardinality; only CLIENTS changes.
        # (Previously each sub-stack saw only `{CLIENTS: n}`.)
        sub_cardinalities = {**cardinalities, placement_literals.CLIENTS: n}
        executors.append(stack_func(cardinalities=sub_cardinalities))  # pytype: disable=wrong-keyword-args
        num_clients -= n
    return _aggregate_stacks(executors, max_fanout, unplaced_ex_factory)
예제 #10
0
 def _validate_requested_clients(
         self, cardinalities: executor_factory.CardinalitiesType) -> int:
     """Returns the client count requested in `cardinalities`.

     Args:
       cardinalities: Mapping whose optional `CLIENTS` entry is the requested
         number of clients.

     Returns:
       The requested `CLIENTS` count, or `self._default_num_clients` when the
       entry is missing or `None`.
     """
     requested = cardinalities.get(placements.CLIENTS)
     return self._default_num_clients if requested is None else requested