Example #1
0
  def test_type_properties(self):
    """Checks the initialize/next type signatures of the SGD finalizer."""
    weights_spec = computation_types.to_type(
        model_utils.ModelWeights(
            trainable=(tf.float32, tf.float32), non_trainable=tf.float32))

    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), weights_spec)
    self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

    # The expected optimizer state holds only the learning rate.
    state_spec = collections.OrderedDict()
    state_spec[optimizer_base.LEARNING_RATE_KEY] = tf.float32
    server_state = computation_types.at_server(
        computation_types.to_type(state_spec))
    server_weights = computation_types.at_server(weights_spec)
    server_update = computation_types.at_server(weights_spec.trainable)
    server_result = computation_types.at_server(weights_spec)
    server_measurements = computation_types.at_server(())

    init_signature = computation_types.FunctionType(
        parameter=None, result=server_state)
    init_signature.check_equivalent_to(finalizer.initialize.type_signature)

    next_signature = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=server_state, weights=server_weights, update=server_update),
        result=MeasuredProcessOutput(server_state, server_result,
                                     server_measurements))
    next_signature.check_equivalent_to(finalizer.next.type_signature)
Example #2
0
  def test_execution_with_stateless_tff_optimizer(self):
    """Runs plain SGD (learning rate 1.0) for five rounds, checking each one."""
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member)

    current_weights = model_utils.ModelWeights(1.0, ())
    delta = 0.1
    state = finalizer.initialize()
    for round_num in range(1, 6):
      output = finalizer.next(state, current_weights, delta)
      state, current_weights = output.state, output.result
      # The learning rate lives in the state and must stay fixed at 1.0.
      self.assertEqual(1.0, state[optimizer_base.LEARNING_RATE_KEY])
      self.assertAllClose(1.0 - 0.1 * round_num, current_weights.trainable)
      self.assertEqual((), output.measurements)
Example #3
0
  def test_execution_with_stateful_tff_optimizer(self):
    """Checks SGD with momentum over consecutive `next` calls.

    The weights returned by one round are fed in as the input weights of the
    next round. The test therefore verifies that both the momentum accumulator
    and the model weights evolve correctly across rounds.
    """
    momentum = 0.5
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0, momentum=momentum), MODEL_WEIGHTS_TYPE.member)

    weights = model_utils.ModelWeights(1.0, ())
    update = 0.1
    expected_velocity = 0.0
    optimizer_state = finalizer.initialize()
    for _ in range(5):
      output = finalizer.next(optimizer_state, weights, update)
      optimizer_state = output.state
      expected_velocity = expected_velocity * momentum + update
      self.assertNear(expected_velocity, optimizer_state['accumulator'], 1e-6)
      # With learning rate 1.0, each step subtracts the current velocity from
      # the weights that were passed into this round.
      self.assertAllClose(weights.trainable - expected_velocity,
                          output.result.trainable)
      self.assertEqual((), output.measurements)
      # Feed the result back in so the next round starts from it; otherwise
      # every round would reuse the initial weights.
      weights = output.result
Example #4
0
  def test_execution_with_nearly_stateless_keras_optimizer(self):
    """Runs Keras SGD (no momentum) for five rounds, checking weights/state."""

    # Keras SGD without momentum keeps only a counter of how many times it
    # has been applied; no other state is used.
    def server_optimizer_fn():
      return tf.keras.optimizers.SGD(learning_rate=1.0)

    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, MODEL_WEIGHTS_TYPE.member)

    model_weights = model_utils.ModelWeights(1.0, ())
    delta = 0.1
    state = finalizer.initialize()
    for step in range(5):
      num_calls = step + 1
      result = finalizer.next(state, model_weights, delta)
      state = result.state
      model_weights = result.result
      # The optimizer state is exactly the number of calls made so far.
      self.assertEqual([num_calls], state)
      self.assertAllClose(1.0 - 0.1 * num_calls, model_weights.trainable)
      self.assertEqual((), result.measurements)
Example #5
0
  def test_execution_with_stateful_keras_optimizer(self):
    """Checks Keras SGD with momentum over consecutive `next` calls."""
    momentum = 0.5

    def server_optimizer_fn():
      # Build the optimizer from the same `momentum` used to compute the
      # expected velocity below, so the two cannot silently diverge.
      return tf.keras.optimizers.SGD(learning_rate=1.0, momentum=momentum)

    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, MODEL_WEIGHTS_TYPE.member)

    weights = model_utils.ModelWeights(1.0, ())
    update = 0.1
    expected_velocity = 0.0
    optimizer_state = finalizer.initialize()
    for i in range(5):
      output = finalizer.next(optimizer_state, weights, update)
      optimizer_state = output.state
      expected_velocity = expected_velocity * momentum + update
      # Keras stores the negative of the velocity term used by
      # tff.learning.optimizers.SGDM
      self.assertAllClose([i + 1, -expected_velocity], optimizer_state)
      self.assertAllClose(weights.trainable - expected_velocity,
                          output.result.trainable)
      self.assertEqual((), output.measurements)
      weights = output.result
Example #6
0
def build_basic_fedavg_process(model_fn: Callable[[], model_lib.Model],
                               client_learning_rate: float):
  """Builds vanilla Federated Averaging process.

  The created process is the basic form of the Federated Averaging algorithm as
  proposed by http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf in
  Algorithm 1, for training the model created by `model_fn`. The following is
  the algorithm in pseudo-code:

  ```
  # Inputs: m: Initial model weights; eta: Client learning rate
  for i in range(num_rounds):
    for c in available_clients_indices:
      delta_m_c, w_c = client_update(m, eta)
    aggregate_model_delta = sum_c(delta_m_c * w_c) / sum_c(w_c)
    m = m - aggregate_model_delta
  return m  # Final trained model.

  def client_update(m, eta):
    initial_m = m
    for batch in client_dataset:
      m = m - eta * grad(m, batch)
    delta_m = initial_m - m
    return delta_m, size(client_dataset)
  ```

  The other algorithm hyper parameters (batch size, number of local epochs) are
  controlled by the data provided to the built process.

  An example usage of the returned `LearningProcess` in simulation:

  ```
  fedavg = build_basic_fedavg_process(model_fn, 0.1)

  # Create a `LearningAlgorithmState` containing the initial model weights for
  # the model returned from `model_fn`.
  state = fedavg.initialize()
  for _ in range(num_rounds):
    client_data = ...  # Preprocessed client datasets
    output = fedavg.next(state, client_data)
    write_round_metrics(output.metrics)
    # The new state contains the updated model weights after this round.
    state = output.state
  ```

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_learning_rate: A float. Learning rate for the SGD at clients.

  Returns:
    A `LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(client_learning_rate, float)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  distributor = distributors.build_broadcast_process(model_weights_type)
  client_work = model_delta_client_work.build_model_delta_client_work(
      model_fn,
      sgdm.build_sgdm(client_learning_rate),
      client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
  # The aggregator is built for exactly the update/weight types emitted by the
  # client work, so the two stages compose without type mismatches.
  aggregator = mean.MeanFactory().create(
      client_work.next.type_signature.result.result.member.update,
      client_work.next.type_signature.result.result.member.update_weight)
  # Server SGD with learning rate 1.0 applies `m = m - aggregate_model_delta`.
  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), model_weights_type)

  return compose_learning_process(initial_model_weights_fn, distributor,
                                  client_work, aggregator, finalizer)
Example #7
0
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process that performs Mime Lite.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  Mime Lite algorithm on client models. The iterative process has the following
  methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `base_optimizer`, whose state is communicated by
  the server and kept intact during local training. The state is updated only
  at the server based on the full gradient evaluated by the clients based on the
  current server model state. The client full gradients are aggregated by
  weighted `full_gradient_aggregator`. Each client computes the difference
  between the client model after training and its initial model. These model
  deltas are then aggregated by weighted `model_aggregator`. Both of the
  aggregations are weighted, according to `client_weighting`. The aggregate
  model delta is applied to the server model using `server_optimizer`.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. This is due to the Mime Lite
  algorithm applying the optimizer without changing its state at clients
  (optimizer's `tf.Variable`s in the case of Keras), which is not possible with
  Keras optimizers without reaching into private implementation details and
  incurring additional computation and memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_type(base_optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(server_optimizer, optimizer_base.Optimizer)
  py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result
  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  # Model deltas have the trainable-weights structure; client weights are
  # scalar float32.
  model_aggregator = model_aggregator.create(
      model_weights_type.trainable, computation_types.TensorType(tf.float32))

  if full_gradient_aggregator is None:
    full_gradient_aggregator = mean.MeanFactory()
  py_typecheck.check_type(full_gradient_aggregator,
                          factory.WeightedAggregationFactory)

  client_work = _build_mime_lite_client_work(
      model_fn=model_fn,
      optimizer=base_optimizer,
      client_weighting=client_weighting,
      full_gradient_aggregator=full_gradient_aggregator,
      metrics_aggregator=metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            model_aggregator, finalizer)
def build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Callable[[], model_lib.Model],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [],
        tf.keras.optimizers.Optimizer]] = fed_avg.DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process for FedAvg with client optimizer scheduling.

  This function creates a `LearningProcess` that performs federated averaging on
  client models. The iterative process has the following methods inherited from
  `LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is sent to each
  client using `model_distributor`. For each client, local training is
  performed using the optimizer built by `client_optimizer_fn`. Each client
  computes the difference between the client model after training and the
  initial broadcast model. These model deltas are then aggregated at the server
  using a weighted aggregation function. Clients are weighted by the number of
  examples they see throughout local training. The aggregate model delta is
  applied at the server using a server optimizer.

  The primary purpose of this implementation of FedAvg is that it allows for the
  client optimizer to be scheduled across rounds. The process keeps track of how
  many iterations of `.next` have occurred (starting at `0`), and for each such
  `round_num`, the clients will use
  `client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
  optimization. This allows learning rate scheduling (eg. starting with a large
  learning rate and decaying it over time) as well as other schedules (eg.
  switching optimizers as learning progresses).

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    client_learning_rate_fn: A callable accepting an integer round number and
      returning a float to be used as a learning rate for the optimizer. The
      client work will call
      `client_optimizer_fn(client_learning_rate_fn(round_num))` where
      `round_num` is the integer round number. Note that the round numbers
      supplied will start at `0` and increment by one each time `.next` is
      called on the resulting process. Also note that this function must be
      serializable by TFF.
    client_optimizer_fn: A callable accepting a float learning rate, and
      returning a `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `LearningProcess`.

  Raises:
    TypeError: If the process created by `model_aggregator` does not preserve
      the type structure of its client inputs in its server result.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(model_weights_type.trainable,
                                       computation_types.TensorType(tf.float32))
  # The finalizer expects the aggregated update to have the same structure as
  # the per-client update, so reject aggregators that change it.
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    raise TypeError('`model_aggregator` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize
  client_work = build_scheduled_client_work(model_fn, client_learning_rate_fn,
                                            client_optimizer_fn,
                                            metrics_aggregator,
                                            use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
Example #9
0
 def test_unexpected_optimizer_fn_raises(self):
   """An already-built Keras optimizer (not a no-arg factory) is rejected."""
   keras_sgd = tf.keras.optimizers.SGD(1.0)
   with self.assertRaises(TypeError):
     finalizers.build_apply_optimizer_finalizer(
         keras_sgd, MODEL_WEIGHTS_TYPE.member)
Example #10
0
 def test_incorrect_value_type_raises(self, bad_type):
   """Building a finalizer for an unsupported value type raises TypeError."""
   # `bad_type` is presumably supplied by a parameterized decorator not
   # visible here — confirm against the test class.
   with self.assertRaises(TypeError):
     server_optimizer = sgdm.build_sgdm(1.0)
     finalizers.build_apply_optimizer_finalizer(server_optimizer, bad_type)
Example #11
0
def build_fed_sgd(
    model_fn: Callable[[], model_lib.Model],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
) -> learning_process.LearningProcess:
  """Builds a learning process that performs federated SGD.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated SGD on client models. The learning process has the following methods
  inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with type signature `( -> S@SERVER)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState`
      representing the initial state of the server.
  *   `next`: A `tff.Computation` with type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)` where `S` is a
      `LearningAlgorithmState` whose type matches that of the output
      of `initialize`, and `{B*}@CLIENTS` represents the client datasets, where
      `B` is the type of a single batch. This computation returns a
      `LearningAlgorithmState` representing the updated server state and the
      metrics during client training and any other metrics from broadcast and
      aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time `next` is called, the server model is broadcast to each client using
  a distributor. Each client sums the gradients for each batch in its local
  dataset (without updating its model) and averages them based on the number
  of examples. These average gradients are then aggregated at the server, and
  are applied at the server using the optimizer given by `server_optimizer_fn`.

  This implements the original FedSGD algorithm in [McMahan et al.,
  2017](https://arxiv.org/abs/1602.05629).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation, returning
      the same pre-constructed model each call will result in an error.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. The optimizer is used to
      apply client updates to the server model.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  # Validate up front (as the other builders do) so an unweighted or wrong
  # factory fails with a clear TypeError rather than inside `create`.
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(
      model_weights_type.trainable, computation_types.TensorType(tf.float32))

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize
  client_work = _build_fed_sgd_client_work(
      model_fn,
      metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
Example #12
0
def build_weighted_fed_prox(
    model_fn: Callable[[], model_lib.Model],
    proximal_strength: float,
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs the FedProx algorithm.

    This function creates a `tff.learning.templates.LearningProcess` that
    performs example-weighted FedProx on client models. This algorithm behaves
    the same as federated averaging, except that it uses a proximal
    regularization term that encourages clients to not drift too far from the
    server model.

    The iterative process has the following methods inherited from
    `tff.learning.templates.LearningProcess`:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` representing the
        initial state of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `{B*}@CLIENTS` represents the client
        datasets. The output `L` contains the updated server state, as well as
        aggregated metrics at the server, including client training metrics
        and any other metrics from distribution and aggregation processes.
    *   `get_model_weights`: A `tff.Computation` with type signature
        `(S -> M)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `next`, and `M` represents the type of the
        model weights used during training.
    *   `set_model_weights`: A `tff.Computation` with type signature
        `(<S, M> -> S)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `M` represents the type of the model
        weights used during training.

    Each time the `next` method is called, the server model is communicated to
    each client using the provided `model_distributor`. For each client, local
    training is performed using `client_optimizer_fn`. Each client computes
    the difference between the client model after training and the initial
    model. These model deltas are then aggregated at the server using a
    weighted aggregation function, according to `client_weighting`. The
    aggregate model delta is applied at the server using a server optimizer,
    as in the FedOpt framework proposed in
    [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

    Note: The default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 1.0, which corresponds to adding the model delta
    to the current server model. This recovers the original FedProx algorithm
    in [Li et al., 2020](https://arxiv.org/abs/1812.06127). More
    sophisticated federated averaging procedures may use different learning
    rates or server optimizers.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use
        them. The model must be constructed entirely from scratch on each
        invocation, returning the same pre-constructed model each call will
        result in an error.
      proximal_strength: A nonnegative float representing the parameter of
        FedProx's regularization term. When set to `0.0`, the algorithm
        reduces to FedAvg. Higher values prevent clients from moving too far
        from the server model during local training.
      client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`. By default, this uses
        `tf.keras.optimizers.SGD` with a learning rate of 1.0.
      client_weighting: A member of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method. By default, weighting by number
        of examples is used.
      model_distributor: An optional `DistributionProcess` that broadcasts the
        model weights on the server to the clients. If set to `None`, the
        distributor is constructed via `distributors.build_broadcast_process`.
      model_aggregator: An optional
        `tff.aggregators.WeightedAggregationFactory` used to aggregate client
        updates on the server. If `None`, this is set to
        `tff.aggregators.MeanFactory`.
      metrics_aggregator: A function that takes in the metric finalizers
        (i.e., `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized
        metrics. If `None`, this is set to
        `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It
        is currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `tff.learning.templates.LearningProcess`.

    Raises:
      ValueError: If `proximal_strength` is not a nonnegative float (NaN is
        rejected as well).
      TypeError: If the `AggregationProcess` built from `model_aggregator`
        does not return, on the server, the same member type it receives from
        the clients.

    Argument types of `model_fn`, `client_weighting` and `model_aggregator`
    are additionally validated with `py_typecheck`.
    """
    # `not proximal_strength >= 0.0` (rather than `proximal_strength < 0.0`)
    # also rejects NaN, for which every ordered comparison evaluates to False.
    if (not isinstance(proximal_strength, float) or
            not proximal_strength >= 0.0):
        raise ValueError(
            'proximal_strength must be a nonnegative float, found {}'.format(
                proximal_strength))

    py_typecheck.check_callable(model_fn)
    # Validate the weighting enum up front, before any computation is traced.
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    # Client updates are deltas of the trainable weights only, weighted by a
    # float32 client weight.
    aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    # The aggregated delta is applied to the server model, so the aggregation
    # process must preserve the member type of its client input.
    process_signature = aggregator.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a compatible '
            '`AggregationProcess`. The processes must retain the type '
            'structure of the inputs on the server, but got '
            f'{input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize
    # The proximal term is implemented as an L2 penalty on the client's delta
    # from the initial (server) model.
    client_work = model_delta_client_work.build_model_delta_client_work(
        model_fn=model_fn,
        optimizer=client_optimizer_fn,
        client_weighting=client_weighting,
        delta_l2_regularizer=proximal_strength,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)
Example #13
0
def build_weighted_fed_avg(
    model_fn: Callable[[], model_lib.Model],
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
    model: Optional[functional.FunctionalModel] = None,
) -> learning_process.LearningProcess:
    """Builds a learning process that performs federated averaging.

    The returned `tff.learning.templates.LearningProcess` runs weighted
    federated averaging over client models and exposes the usual
    `LearningProcess` methods:

    *   `initialize`: `( -> S@SERVER)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` holding the initial
        server state.
    *   `next`: `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)`, where `S` matches
        the output of `initialize` and `{B*}@CLIENTS` are the client datasets.
        `L` holds the updated server state together with aggregated metrics
        (client training metrics plus metrics from the distribution and
        aggregation processes).
    *   `get_model_weights`: `(S -> M)`, extracting the model weights `M` used
        during training from a state `S` produced by `initialize` or `next`.
    *   `set_model_weights`: `(<S, M> -> S)`, replacing the model weights `M`
        inside a state `S` whose type matches the output of `initialize`.

    On every call to `next`, the server model is sent to the clients via
    `model_distributor`. Each client trains locally with
    `client_optimizer_fn` and reports the difference between its trained model
    and the model it received. The deltas are combined on the server by a
    weighted aggregation (weights chosen by `client_weighting`), and the
    resulting aggregate delta is applied with the server optimizer.

    Note: the default server optimizer is `tf.keras.optimizers.SGD` with a
    learning rate of 1.0, i.e. the aggregate delta is simply added to the
    server model. That is the original FedAvg of
    [McMahan et al., 2017](https://arxiv.org/abs/1602.05629); other learning
    rates or server optimizers give the generalized FedAvg described in
    [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

    Args:
      model_fn: No-arg function returning a `tff.learning.Model`. It must
        *not* capture TensorFlow tensors or variables; the model has to be
        built from scratch on every call (returning a pre-built model each
        time is an error). Mutually exclusive with `model`.
      client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable returning a `tf.keras.Optimizer`.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable returning a `tf.keras.Optimizer`. Defaults to
        `tf.keras.optimizers.SGD` with a learning rate of 1.0.
      client_weighting: A `tff.learning.ClientWeighting` member selecting the
        built-in weighting scheme. Defaults to weighting by number of
        examples.
      model_distributor: Optional `DistributionProcess` sending the server
        model weights to the clients. When `None`, one is built with
        `tff.learning.templates.build_broadcast_process`.
      model_aggregator: Optional `tff.aggregators.WeightedAggregationFactory`
        for aggregating client updates. When `None`,
        `tff.aggregators.MeanFactory` is used.
      metrics_aggregator: Function taking the metric finalizers (i.e.
        `tff.learning.Model.metric_finalizers()`) and the
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e. the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`)
        and returning a `tff.Computation` that aggregates them. When `None`,
        `tff.learning.metrics.sum_then_finalize` is used.
      use_experimental_simulation_loop: Selects the experimental reduce loop
        over the input dataset, currently needed for performant GPU
        simulations. Only forwarded when training through `model_fn`; it is
        not passed on when `model` is given.
      model: A `tff.learning.models.FunctionalModel` to train. Mutually
        exclusive with `model_fn` (which must then be `None`).

    Returns:
      A `tff.learning.templates.LearningProcess`.

    Raises:
      ValueError: If both or neither of `model` and `model_fn` are provided.
      TypeError: If the `AggregationProcess` built from `model_aggregator`
        does not return, on the server, the same member type it receives from
        the clients.

    Types of `model_fn`/`model`, `client_weighting` and `model_aggregator` are
    additionally validated with `py_typecheck`.
    """
    uses_functional_model = model is not None
    # Exactly one of `model` / `model_fn` must be supplied.
    if uses_functional_model == (model_fn is not None):
        if uses_functional_model:
            raise ValueError('Must specify only one of `model` and `model_fn`, '
                             'both were not `None`.')
        raise ValueError('Must specify one of `model` and `model_fn`, both '
                         'were `None`.')

    if uses_functional_model:
        py_typecheck.check_type(model, functional.FunctionalModel)
    else:
        py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    if uses_functional_model:

        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            # `initial_weights[0]` / `[1]` are the trainable / non-trainable
            # weight sequences, converted to tensors for the computation.
            initial_weights = model.initial_weights
            return model_utils.ModelWeights(
                trainable=tuple(map(tf.convert_to_tensor, initial_weights[0])),
                non_trainable=tuple(
                    map(tf.convert_to_tensor, initial_weights[1])))
    else:

        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    model_distributor = (
        distributors.build_broadcast_process(model_weights_type)
        if model_distributor is None else model_distributor)

    model_aggregator = (
        mean.MeanFactory() if model_aggregator is None else model_aggregator)
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)

    # Only the trainable part of the weights is aggregated as the update.
    if uses_functional_model:
        model_update_type, _ = model_weights_type
    else:
        model_update_type = model_weights_type.trainable
    aggregator = model_aggregator.create(
        model_update_type, computation_types.TensorType(tf.float32))

    # The aggregate is applied to the server model, so the process must hand
    # back the same member type it receives from the clients.
    next_signature = aggregator.next.type_signature
    client_value_type = next_signature.parameter[1]
    server_value_type = next_signature.result[1]
    if client_value_type.member != server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a compatible '
            '`AggregationProcess`. The processes must retain the type '
            'structure of the inputs on the server, but got '
            f'{client_value_type.member} != {server_value_type.member}.')

    metrics_aggregator = (
        metric_aggregator.sum_then_finalize
        if metrics_aggregator is None else metrics_aggregator)

    shared_client_work_kwargs = dict(
        optimizer=client_optimizer_fn,
        client_weighting=client_weighting,
        metrics_aggregator=metrics_aggregator)
    if uses_functional_model:
        # NOTE: `use_experimental_simulation_loop` is not forwarded here.
        client_work = (
            model_delta_client_work.build_functional_model_delta_client_work(
                model=model, **shared_client_work_kwargs))
    else:
        client_work = model_delta_client_work.build_model_delta_client_work(
            model_fn=model_fn,
            use_experimental_simulation_loop=use_experimental_simulation_loop,
            **shared_client_work_kwargs)

    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(
        initial_model_weights_fn, model_distributor, client_work, aggregator,
        finalizer)