Example #1
0
def build_federated_averaging_process(
    model_fn: Callable[[], model_lib.Model],
    client_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
    server_optimizer_fn: Callable[
        [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn: Callable[[Any], tf.Tensor] = None,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None) -> tff.utils.IterativeProcess:
  """Constructs the iterative process implementing federated averaging.

  Args:
    model_fn: Zero-argument callable producing a `tff.learning.Model`.
    client_optimizer_fn: Zero-argument callable producing the
      `tf.keras.Optimizer` used for local training on each client.
    server_optimizer_fn: Zero-argument callable producing the
      `tf.keras.Optimizer` whose `apply_gradients` method applies client
      updates to the server model. The default builds a
      `tf.keras.optimizers.SGD` with learning rate 1.0, i.e. the average client
      delta is simply added to the server's model.
    client_weight_fn: Optional callable mapping the output of
      `model.report_local_outputs` to a tensor giving this client's weight in
      the federated average of model deltas. When omitted, the weight is the
      total number of examples processed on device.
    stateful_delta_aggregate_fn: Optional `tff.utils.StatefulAggregateFn`
      whose `next_fn` performs a federated aggregation and updates state. Its
      TFF type is `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where `value` has the type of
      `tff.learning.framework.ModelWeights.trainable` for the object returned
      by `model_fn`. When omitted, a stateless arithmetic mean weighted by
      `client_weight_fn` is used.
    stateful_model_broadcast_fn: Optional `tff.utils.StatefulBroadcastFn`
      whose `next_fn` performs a federated broadcast and updates state. Its
      TFF type is `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where `value` has the type of
      `tff.learning.framework.ModelWeights` for the object returned by
      `model_fn`. When omitted, a stateless identity broadcast is used.

  Returns:
    A `tff.utils.IterativeProcess`.
  """
  # Resolve the aggregation function: type-check a caller-supplied one,
  # otherwise fall back to the stateless mean.
  if stateful_delta_aggregate_fn is not None:
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)
  else:
    stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()

  # Same treatment for the broadcast function.
  if stateful_model_broadcast_fn is not None:
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)
  else:
    stateful_model_broadcast_fn = optimizer_utils.build_stateless_broadcaster()

  def client_fed_avg(model_fn):
    # A new model and a new client optimizer are created on every invocation.
    local_model = model_fn()
    local_optimizer = client_optimizer_fn()
    return _ClientFedAvg(local_model, local_optimizer, client_weight_fn)

  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn, client_fed_avg, server_optimizer_fn,
      stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
  return iterative_process
Example #2
0
 def test_fails_stateful_aggregate_and_process(self):
     """Checks that supplying both aggregation arguments is rejected.

     `build_federated_averaging_process` is invoked with both a
     `stateful_delta_aggregate_fn` and an `aggregation_process`; the call is
     expected to raise `optimizer_utils.DisjointArgumentError`.
     """
     linear_model_fn = model_examples.LinearRegression
     weights_type = model_utils.weights_type_from_model(linear_model_fn)

     def mean_next_fn(state, value, weight=None):
         # State is passed through unchanged; only the mean is computed.
         return state, tff.federated_mean(value, weight)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         # Both aggregation arguments are built inside the assertion scope so
         # the scope covers exactly the same calls as a single inline call.
         legacy_aggregate_fn = tff.utils.StatefulAggregateFn(
             initialize_fn=lambda: (), next_fn=mean_next_fn)
         mean_process = optimizer_utils.build_stateless_mean(
             model_delta_type=weights_type.trainable)
         federated_averaging.build_federated_averaging_process(
             model_fn=linear_model_fn,
             client_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_delta_aggregate_fn=legacy_aggregate_fn,
             aggregation_process=mean_process)
Example #3
0
 def test_fails_stateful_aggregate_and_process(self):
     """Checks that supplying both aggregation arguments is rejected.

     `optimizer_utils.build_model_delta_optimizer_process` is invoked with both
     a `stateful_delta_aggregate_fn` and an `aggregation_process`; the call is
     expected to raise `optimizer_utils.DisjointArgumentError`.
     """
     # The throwaway model used to derive the weights type is created inside
     # its own graph context.
     with tf.Graph().as_default():
         scratch_model = model_examples.LinearRegression()
         scratch_weights = model_utils.ModelWeights.from_model(scratch_model)
         model_weights_type = tff.framework.type_from_tensors(scratch_weights)

     def mean_next_fn(state, value, weight=None):
         # State is passed through unchanged; only the mean is computed.
         return state, tff.federated_mean(value, weight)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         # Both aggregation arguments are built inside the assertion scope so
         # the scope covers exactly the same calls as a single inline call.
         legacy_aggregate_fn = tff.utils.StatefulAggregateFn(
             initialize_fn=lambda: (), next_fn=mean_next_fn)
         mean_process = optimizer_utils.build_stateless_mean(
             model_delta_type=model_weights_type.trainable)
         optimizer_utils.build_model_delta_optimizer_process(
             model_fn=model_examples.LinearRegression,
             model_to_client_delta_fn=DummyClientDeltaFn,
             server_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_delta_aggregate_fn=legacy_aggregate_fn,
             aggregation_process=mean_process)
Example #4
0
def build_federated_averaging_process(
        model_fn: Callable[[], model_lib.Model],
        client_optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
        server_optimizer_fn: Callable[
            [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
        client_weight_fn: Callable[[Any], tf.Tensor] = None,
        stateful_delta_aggregate_fn=None,
        stateful_model_broadcast_fn=None) -> tff.utils.IterativeProcess:
    """Constructs an iterative process that runs federated averaging.

    The returned `tff.utils.IterativeProcess` exposes two methods:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
        representing the initial state of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)`. `S` is a
        `tff.learning.framework.ServerState` whose type matches the output of
        `initialize`, and `{B*}@CLIENTS` represents the client datasets, with
        `B` the type of a single batch. The computation returns the updated
        `tff.learning.framework.ServerState` together with the training
        metrics produced by
        `tff.learning.Model.federated_output_computation` during client
        training.

    On each call to `next`, the server model is broadcast to every client via
    the broadcast function. Each client runs one epoch of local training using
    the `tf.keras.optimizers.Optimizer.apply_gradients` method of the client
    optimizer, then computes the difference between its trained model and the
    model it received. The server aggregates these model deltas with the
    aggregation function and applies the aggregate through the
    `tf.keras.optimizers.Optimizer.apply_gradients` method of the server
    optimizer.

    Note: the default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 1.0, which corresponds to adding the model delta
    to the current server model. This recovers the original FedAvg algorithm
    in [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
    sophisticated federated averaging procedures may use different learning
    rates or server optimizers.

    Args:
      model_fn: Zero-argument callable producing a `tff.learning.Model`.
      client_optimizer_fn: Zero-argument callable producing a
        `tf.keras.Optimizer` for local client training.
      server_optimizer_fn: Zero-argument callable producing a
        `tf.keras.Optimizer`. Defaults to `tf.keras.optimizers.SGD` with a
        learning rate of 1.0.
      client_weight_fn: Optional callable mapping the output of
        `model.report_local_outputs` to a tensor giving the client's weight in
        the federated average of model deltas. When omitted, the weight is the
        total number of examples processed on device.
      stateful_delta_aggregate_fn: Optional `tff.utils.StatefulAggregateFn`
        whose `next_fn` performs a federated aggregation and updates state.
        It must have TFF type `(<state@SERVER, value@CLIENTS,
        weights@CLIENTS> -> <state@SERVER, aggregate@SERVER>)`, where `value`
        has the type of `tff.learning.framework.ModelWeights.trainable` for
        the object returned by `model_fn`. When omitted, an arithmetic mean
        weighted by `client_weight_fn` is used.
      stateful_model_broadcast_fn: Optional `tff.utils.StatefulBroadcastFn`
        whose `next_fn` performs a federated broadcast and updates state. It
        must have TFF type `(<state@SERVER, value@SERVER> -> <state@SERVER,
        value@CLIENTS>)`, where `value` has the type of
        `tff.learning.framework.ModelWeights` for the object returned by
        `model_fn`. When omitted, the identity broadcast is used.

    Returns:
      A `tff.utils.IterativeProcess`.
    """
    # Resolve the aggregation function: type-check a caller-supplied one,
    # otherwise fall back to the stateless mean.
    if stateful_delta_aggregate_fn is not None:
        py_typecheck.check_type(stateful_delta_aggregate_fn,
                                tff.utils.StatefulAggregateFn)
    else:
        stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()

    # Same treatment for the broadcast function.
    if stateful_model_broadcast_fn is not None:
        py_typecheck.check_type(stateful_model_broadcast_fn,
                                tff.utils.StatefulBroadcastFn)
    else:
        stateful_model_broadcast_fn = (
            optimizer_utils.build_stateless_broadcaster())

    def client_fed_avg(model_fn):
        # A new model and a new client optimizer are created on every call.
        local_model = model_fn()
        local_optimizer = client_optimizer_fn()
        return ClientFedAvg(local_model, local_optimizer, client_weight_fn)

    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn, client_fed_avg, server_optimizer_fn,
        stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    return iterative_process
Example #5
0
def build_federated_sgd_process(
        model_fn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
        client_weight_fn=None,
        stateful_delta_aggregate_fn=None,
        stateful_model_broadcast_fn=None):
    """Constructs the TFF computations for optimization using federated SGD.

    The returned `tff.templates.IterativeProcess` exposes two methods:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a `tff.learning.framework.ServerState`
        representing the initial state of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)`. `S` is a
        `tff.learning.framework.ServerState` whose type matches the output of
        `initialize`, and `{B*}@CLIENTS` represents the client datasets, with
        `B` the type of a single batch. The computation returns the updated
        `tff.learning.framework.ServerState` together with the training
        metrics produced by
        `tff.learning.Model.federated_output_computation` during client
        training.

    On each call to `next`, the server model is broadcast to every client via
    the broadcast function. Each client sums the gradients at each batch in
    its local dataset. The server aggregates these gradient sums with the
    aggregation function and applies the aggregate through the
    `tf.keras.optimizers.Optimizer.apply_gradients` method of the server
    optimizer.

    Note: the default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 0.1. FedSGD is described in
    [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
    sophisticated federated SGD procedures may use different learning rates
    or server optimizers.

    Args:
      model_fn: Zero-argument callable producing a `tff.learning.Model`. It
        must *not* capture TensorFlow tensors or variables and use them. The
        model must be constructed entirely from scratch on each invocation;
        returning the same pre-constructed model each call will result in an
        error.
      server_optimizer_fn: Zero-argument callable producing the optimizer
        whose `apply_gradients` method applies client updates to the server
        model. Defaults to `tf.keras.optimizers.SGD(learning_rate=0.1)`.
      client_weight_fn: Optional callable mapping the output of
        `model.report_local_outputs` to a tensor giving the client's weight in
        the federated average of the aggregated gradients. When omitted, the
        weight is the total number of examples processed on device.
      stateful_delta_aggregate_fn: Optional `tff.utils.StatefulAggregateFn`
        whose `next_fn` performs a federated aggregation and updates state.
        It must have TFF type `(<state@SERVER, value@CLIENTS,
        weights@CLIENTS> -> <state@SERVER, aggregate@SERVER>)`, where `value`
        has the type of `tff.learning.framework.ModelWeights.trainable` for
        the object returned by `model_fn`. When omitted, an arithmetic mean
        weighted by `client_weight_fn` is used.
      stateful_model_broadcast_fn: Optional `tff.utils.StatefulBroadcastFn`
        whose `next_fn` performs a federated broadcast and updates state. It
        must have TFF type `(<state@SERVER, value@SERVER> -> <state@SERVER,
        value@CLIENTS>)`, where `value` has the type of
        `tff.learning.framework.ModelWeights` for the object returned by
        `model_fn`. When omitted, the identity broadcast is used.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # Resolve the aggregation function: type-check a caller-supplied one,
    # otherwise fall back to the stateless mean.
    if stateful_delta_aggregate_fn is not None:
        py_typecheck.check_type(stateful_delta_aggregate_fn,
                                tff.utils.StatefulAggregateFn)
    else:
        stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()

    # Same treatment for the broadcast function.
    if stateful_model_broadcast_fn is not None:
        py_typecheck.check_type(stateful_model_broadcast_fn,
                                tff.utils.StatefulBroadcastFn)
    else:
        stateful_model_broadcast_fn = (
            optimizer_utils.build_stateless_broadcaster())

    def client_sgd_avg(model_fn):
        # A new model is created on every call.
        local_model = model_fn()
        return ClientSgd(local_model, client_weight_fn)

    iterative_process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn, client_sgd_avg, server_optimizer_fn,
        stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    return iterative_process
Example #6
0
def build_federated_averaging_process(
    model_fn: Callable[[], model_lib.Model],
    client_optimizer_fn: Optional[Callable[
        [], tf.keras.optimizers.Optimizer]] = None,
    server_optimizer_fn: Callable[
        [], tf.keras.optimizers.Optimizer] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    stateful_delta_aggregate_fn=None,
    stateful_model_broadcast_fn=None) -> tff.utils.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: An optional no-arg callable that returns a
      `tf.keras.Optimizer`. When `None`, a warning is issued and the deprecated
      client update path (`_DeprecatedClientFedAvg`) is used, which receives
      only the model and `client_weight_fn`.
    server_optimizer_fn: A no-arg callable that returns a `tf.keras.Optimizer`.
      The `apply_gradients` method of this optimizer is used to apply client
      updates to the server model. The default creates a
      `tf.keras.optimizers.SGD` with a learning rate of 1.0, which simply adds
      the average client delta to the server's model.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.
    stateful_delta_aggregate_fn: A `tff.utils.StatefulAggregateFn` where the
      `next_fn` performs a federated aggregation and updates state. That is, it
      has TFF type `(state@SERVER, value@CLIENTS, weights@CLIENTS) ->
      (state@SERVER, aggregate@SERVER)`, where the `value` type is
      `tff.learning.framework.ModelWeights.trainable` corresponding to the
      object returned by `model_fn`. By default performs arithmetic mean
      aggregation, weighted by `client_weight_fn`.
    stateful_model_broadcast_fn: A `tff.utils.StatefulBroadcastFn` where the
      `next_fn` performs a federated broadcast and updates state. That is, it
      has TFF type `(state@SERVER, value@SERVER) -> (state@SERVER,
      value@CLIENTS)`, where the `value` type is
      `tff.learning.framework.ModelWeights` corresponding to the object returned
      by `model_fn`. By default performs identity broadcast.

  Returns:
    A `tff.utils.IterativeProcess`.

  Raises:
    TypeError: If `client_optimizer_fn` is provided and `model_fn` produces a
      deprecated `tff.learning.TrainableModel`, or if `client_optimizer_fn` is
      neither `None` nor callable (raised when the client update function is
      constructed).
  """
  if client_optimizer_fn is None:
    warnings.warn('tff.learning.build_federated_averaging_process will start '
                  'requiring a new argument \'client_optimizer_fn\'. Specify '
                  'the local client optimizer here rather than building a '
                  'tff.learning.TrainableModel')
  else:
    # Validate parameters and surfacing errors early requires building a
    # throwaway model here. It is built in a scratch graph so it does not
    # end up in whatever graph is current for the caller.
    with tf.Graph().as_default():
      model = model_fn()
      if isinstance(model, model_lib.TrainableModel):
        raise TypeError('model_fn parameter should be a callable that produces '
                        'tff.learning.Model, not the deprecated '
                        'tff.learning.TrainableModel')

  def client_fed_avg(model_fn):
    if client_optimizer_fn is None:
      # Deprecated path: no client optimizer is passed to the client update.
      return _DeprecatedClientFedAvg(model_fn(), client_weight_fn)
    elif callable(client_optimizer_fn):
      return _ClientFedAvg(model_fn(), client_optimizer_fn(), client_weight_fn)
    else:
      # The `f` prefix must sit on the fragment that contains the replacement
      # field; adjacent literals are concatenated, but each literal is
      # formatted on its own.
      raise TypeError('client_optimizer_fn parameter of '
                      'tff.learning.build_federated_averaging_process must be '
                      f'a callable. Received a {type(client_optimizer_fn)}')

  if stateful_delta_aggregate_fn is None:
    stateful_delta_aggregate_fn = optimizer_utils.build_stateless_mean()
  else:
    py_typecheck.check_type(stateful_delta_aggregate_fn,
                            tff.utils.StatefulAggregateFn)

  if stateful_model_broadcast_fn is None:
    stateful_model_broadcast_fn = optimizer_utils.build_stateless_broadcaster()
  else:
    py_typecheck.check_type(stateful_model_broadcast_fn,
                            tff.utils.StatefulBroadcastFn)

  return optimizer_utils.build_model_delta_optimizer_process(
      model_fn, client_fed_avg, server_optimizer_fn,
      stateful_delta_aggregate_fn, stateful_model_broadcast_fn)