Example #1
    def test_unweighted_aggregator_raises(self):
        bad_aggregator = sum_factory.SumFactory().create(FLOAT_TYPE)
        with self.assertRaisesRegex(TypeError, 'weighted'):
            composers.compose_learning_process(test_init_model_weights_fn,
                                               test_distributor(),
                                               test_client_work(),
                                               bad_aggregator,
                                               test_finalizer())
Example #2
    def test_not_finalizer_type_raises(self):
        finalizer = test_finalizer()
        bad_finalizer = measured_process.MeasuredProcess(
            finalizer.initialize, finalizer.next)
        with self.assertRaisesRegex(TypeError, 'FinalizerProcess'):
            composers.compose_learning_process(test_init_model_weights_fn,
                                               test_distributor(),
                                               test_client_work(),
                                               test_aggregator(),
                                               bad_finalizer)
Example #3
    def test_not_tff_computation_init_raises(self):
        def init_model_weights_fn():
            return model_utils.ModelWeights(trainable=tf.constant(1.0),
                                            non_trainable=())

        with self.assertRaisesRegex(TypeError, 'Computation'):
            composers.compose_learning_process(init_model_weights_fn,
                                               test_distributor(),
                                               test_client_work(),
                                               test_aggregator(),
                                               test_finalizer())
Example #4
    def test_federated_init_raises(self):
        @federated_computation.federated_computation()
        def init_model_weights_fn():
            return intrinsics.federated_eval(test_init_model_weights_fn,
                                             placements.SERVER)

        with self.assertRaisesRegex(TypeError, 'unplaced'):
            composers.compose_learning_process(init_model_weights_fn,
                                               test_distributor(),
                                               test_client_work(),
                                               test_aggregator(),
                                               test_finalizer())
Example #5
    def test_one_arg_computation_init_raises(self):
        @tensorflow_computation.tf_computation(
            computation_types.TensorType(tf.float32))
        def init_model_weights_fn(x):
            return model_utils.ModelWeights(trainable=x, non_trainable=())

        with self.assertRaisesRegex(TypeError, 'Computation'):
            composers.compose_learning_process(init_model_weights_fn,
                                               test_distributor(),
                                               test_client_work(),
                                               test_aggregator(),
                                               test_finalizer())
Example #6
    def test_learning_process_composes(self):
        process = composers.compose_learning_process(
            test_init_model_weights_fn, test_distributor(), test_client_work(),
            test_aggregator(), test_finalizer())

        self.assertIsInstance(process, learning_process.LearningProcess)
        self.assertEqual(
            process.initialize.type_signature.result.member.python_container,
            composers.LearningAlgorithmState)
        self.assertEqual(
            process.initialize.type_signature.result.member.
            global_model_weights, MODEL_WEIGHTS_TYPE)

        # Reported metrics have the expected fields.
        metrics_type = process.next.type_signature.result.metrics.member
        self.assertTrue(structure.has_field(metrics_type, 'distributor'))
        self.assertTrue(structure.has_field(metrics_type, 'client_work'))
        self.assertTrue(structure.has_field(metrics_type, 'aggregator'))
        self.assertTrue(structure.has_field(metrics_type, 'finalizer'))
        self.assertLen(metrics_type, 4)
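    # Taken together, the tests above pin down the contract of
    # `composers.compose_learning_process`: an unplaced no-arg initializer plus
    # four typed template processes. A hedged sketch of the call shape (the
    # `test_*` helpers are fixtures defined elsewhere in this test module):
    #
    #   process = composers.compose_learning_process(
    #       test_init_model_weights_fn,  # no-arg computation, unplaced result
    #       test_distributor(),          # DistributionProcess
    #       test_client_work(),          # ClientWorkProcess
    #       test_aggregator(),           # weighted AggregationProcess
    #       test_finalizer())            # FinalizerProcess
    #   state = process.initialize()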
Example #7
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs Mime Lite.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  the Mime Lite algorithm on client models. The iterative process has the
  following methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `base_optimizer`, whose state is communicated by
  the server and kept intact during local training. The optimizer state is
  updated only at the server, based on the full gradients evaluated by the
  clients at the current server model state. The clients' full gradients are
  aggregated by the weighted `full_gradient_aggregator`. Each client computes
  the difference between its model after training and its initial model. These
  model deltas are then aggregated by the weighted `model_aggregator`. Both
  aggregations are weighted according to `client_weighting`. The aggregate
  model delta is added to the existing server model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. This is because the Mime Lite
  algorithm applies the optimizer at clients without changing its state (the
  optimizer's `tf.Variable`s, in the case of Keras), which is not possible with
  Keras optimizers without reaching into private implementation details and
  incurring additional computation and memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(base_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(server_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result
    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)
    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    model_aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)

    client_work = _build_mime_lite_client_work(
        model_fn=model_fn,
        optimizer=base_optimizer,
        client_weighting=client_weighting,
        full_gradient_aggregator=full_gradient_aggregator,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              model_aggregator, finalizer)
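# A minimal end-to-end sketch of driving the process built above; this is an
# illustration, not part of the library. `tff.learning.from_keras_model` is the
# public-API way to build a `tff.learning.Model`, and the tiny model and client
# datasets below are made-up stand-ins.
import tensorflow as tf
import tensorflow_federated as tff


def example_model_fn():
    keras_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(2,))])
    return tff.learning.from_keras_model(
        keras_model,
        loss=tf.keras.losses.MeanSquaredError(),
        input_spec=(tf.TensorSpec([None, 2], tf.float32),
                    tf.TensorSpec([None, 1], tf.float32)))


example_client_data = [
    tf.data.Dataset.from_tensor_slices(
        (tf.random.normal([4, 2]), tf.random.normal([4, 1]))).batch(2)
    for _ in range(2)
]

mime_process = build_weighted_mime_lite(
    example_model_fn,
    base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
mime_state = mime_process.initialize()
mime_output = mime_process.next(mime_state, example_client_data)  # one round
mime_state, mime_metrics = mime_output.state, mime_output.metrics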
Example #8
def build_weighted_fed_avg_with_optimizer_schedule(
    model_fn: Callable[[], model_lib.Model],
    client_learning_rate_fn: Callable[[int], float],
    client_optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [],
        tf.keras.optimizers.Optimizer]] = fed_avg.DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
  """Builds a learning process for FedAvg with client optimizer scheduling.

  This function creates a `LearningProcess` that performs federated averaging on
  client models. The iterative process has the following methods inherited from
  `LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is broadcast to each
  client using a broadcast function. For each client, local training is
  performed using `client_optimizer_fn`. Each client computes the difference
  between the client model after training and the initial broadcast model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function. Clients are weighted by the number of examples they see
  throughout local training. The aggregate model delta is applied at the server
  using a server optimizer.

  The primary purpose of this implementation of FedAvg is that it allows for the
  client optimizer to be scheduled across rounds. The process keeps track of how
  many iterations of `.next` have occurred (starting at `0`), and for each such
  `round_num`, the clients will use
  `client_optimizer_fn(client_learning_rate_fn(round_num))` to perform local
  optimization. This enables learning rate scheduling (e.g. starting with a
  large learning rate and decaying it over time) as well as optimizer
  scheduling (e.g. switching optimizers as learning progresses).

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    client_learning_rate_fn: A callable accepting an integer round number and
      returning a float to be used as a learning rate for the optimizer. The
      client work will call `optimizer_fn(learning_rate_fn(round_num))` where
      `round_num` is the integer round number. Note that the round numbers
      supplied will start at `0` and increment by one each time `.next` is
      called on the resulting process. Also note that this function must be
      serializable by TFF.
    client_optimizer_fn: A callable accepting a float learning rate, and
      returning a `tff.learning.optimizers.Optimizer` or a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `LearningProcess`.
  """
  py_typecheck.check_callable(model_fn)

  @tensorflow_computation.tf_computation()
  def initial_model_weights_fn():
    return model_utils.ModelWeights.from_model(model_fn())

  model_weights_type = initial_model_weights_fn.type_signature.result

  if model_distributor is None:
    model_distributor = distributors.build_broadcast_process(model_weights_type)

  if model_aggregator is None:
    model_aggregator = mean.MeanFactory()
  py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory)
  aggregator = model_aggregator.create(model_weights_type.trainable,
                                       computation_types.TensorType(tf.float32))
  process_signature = aggregator.next.type_signature
  input_client_value_type = process_signature.parameter[1]
  result_server_value_type = process_signature.result[1]
  if input_client_value_type.member != result_server_value_type.member:
    raise TypeError('`model_aggregator` does not produce a '
                    'compatible `AggregationProcess`. The processes must '
                    'retain the type structure of the inputs on the '
                    f'server, but got {input_client_value_type.member} != '
                    f'{result_server_value_type.member}.')

  if metrics_aggregator is None:
    metrics_aggregator = metric_aggregator.sum_then_finalize
  client_work = build_scheduled_client_work(model_fn, client_learning_rate_fn,
                                            client_optimizer_fn,
                                            metrics_aggregator,
                                            use_experimental_simulation_loop)
  finalizer = finalizers.build_apply_optimizer_finalizer(
      server_optimizer_fn, model_weights_type)
  return composers.compose_learning_process(initial_model_weights_fn,
                                            model_distributor, client_work,
                                            aggregator, finalizer)
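# A hedged usage sketch of the schedule: the client learning rate drops from
# 0.1 to 0.01 at round 5, reusing the illustrative `example_model_fn` from the
# Mime Lite sketch above. `client_learning_rate_fn` must be serializable by
# TFF, so it is restricted to TF ops on the (tensor-valued) round number.
import tensorflow as tf


def example_client_learning_rate_fn(round_num):
    # 0.1 for rounds 0-4, then 0.01 from round 5 onwards.
    return tf.where(round_num < 5, 0.1, 0.01)


scheduled_process = build_weighted_fed_avg_with_optimizer_schedule(
    example_model_fn,
    client_learning_rate_fn=example_client_learning_rate_fn,
    client_optimizer_fn=lambda lr: sgdm.build_sgdm(learning_rate=lr))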
Example #9
def build_fed_kmeans(
    num_clusters: int,
    data_shape: Tuple[int, ...],
    random_seed: Optional[Tuple[int, int]] = None,
    distributor: Optional[distributors.DistributionProcess] = None,
    sum_aggregator: Optional[factory.UnweightedAggregationFactory] = None,
) -> learning_process.LearningProcess:
    """Builds a learning process for federated k-means clustering.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated k-means clustering. Specifically, this performs mini-batch k-means
  clustering. Note that mini-batch k-means only processes a mini-batch of the
  data at each round, and updates clusters in a weighted manner based on how
  many points in the mini-batch were assigned to each cluster. In the federated
  version, clients assign each of their points locally, and the
  server updates the clusters. Conceptually, the "mini-batch" being used is the
  union of all client datasets involved in a given round.

  The learning process has the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a `LearningAlgorithmState` representing the
      initial state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `LearningAlgorithmState` whose type matches the output of `initialize`
      and `{B*}@CLIENTS` represents the client datasets. The output `L` is a
      `tff.learning.templates.LearningProcessOutput` containing the state `S`
      and metrics computed during training.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> W)`,
      where `W` represents the current k-means centroids.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` a new set of k-means centroids.

  Here, `S` is a `tff.learning.templates.LearningAlgorithmState`. The centroids
  `W` are represented as a single tensor of shape
  `(num_clusters,) + data_shape`. The datasets `{B*}` must have elements of
  shape `data_shape`, and must not be batched.

  At each round, every client point is assigned to its nearest centroid, and
  the points assigned to each centroid are summed. The server then updates the
  centroids based on these sums. To do so, we keep track of how many points
  have been assigned to each centroid overall, as an integer tensor of shape
  `(num_clusters,)`. This information can be found in `state.finalizer`. Note
  that we begin with a "pseudo-count" of 1 in order to ensure that the
  centroids do not collapse to zero.

  Args:
    num_clusters: The number of clusters to use.
    data_shape: A tuple of integers specifying the shape of each data point.
      Note that this data shape should be unbatched, as this algorithm does not
      currently support batched data points.
    random_seed: A tuple of two integers used to seed the initialization phase.
    distributor: An optional `tff.learning.templates.DistributionProcess` that
      broadcasts the centroids on the server to the clients. If set to `None`,
      the distributor is constructed via
      `tff.learning.templates.build_broadcast_process`.
    sum_aggregator: An optional `tff.aggregators.UnweightedAggregationFactory`
      used to sum updates across clients. If `None`, we use
      `tff.aggregators.SumFactory`.

  Returns:
    A `LearningProcess`.
  """
    centroids_shape = (num_clusters, ) + data_shape

    if not random_seed:
        random_seed = (tf.cast(tf.timestamp() * _MILLIS_PER_SECOND,
                               tf.int64).numpy(), 0)

    @tensorflow_computation.tf_computation
    def initialize_centers():
        return tf.random.stateless_normal(centroids_shape,
                                          random_seed,
                                          dtype=_POINT_DTYPE)

    centroids_type = computation_types.TensorType(_POINT_DTYPE,
                                                  centroids_shape)
    weights_type = computation_types.TensorType(_WEIGHT_DTYPE,
                                                shape=(num_clusters, ))
    point_type = computation_types.TensorType(_POINT_DTYPE, shape=data_shape)
    data_type = computation_types.SequenceType(point_type)

    if distributor is None:
        distributor = distributors.build_broadcast_process(centroids_type)

    client_work = _build_kmeans_client_work(centroids_type, data_type)

    if sum_aggregator is None:
        sum_aggregator = sum_factory.SumFactory()
    # We wrap the sum factory as a weighted aggregator for compatibility with
    # the learning process composer.
    weighted_aggregator = factory_utils.as_weighted_aggregator(sum_aggregator)
    value_type = computation_types.to_type((centroids_type, weights_type))
    aggregator = weighted_aggregator.create(value_type,
                                            computation_types.to_type(()))

    finalizer = _build_kmeans_finalizer(centroids_type, num_clusters)

    return composers.compose_learning_process(initialize_centers, distributor,
                                              client_work, aggregator,
                                              finalizer)
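# A hedged usage sketch for `build_fed_kmeans`. Per the docstring, client
# datasets must yield unbatched points of shape `data_shape`; the random
# points below are illustrative stand-ins.
import tensorflow as tf

kmeans_process = build_fed_kmeans(
    num_clusters=3, data_shape=(2,), random_seed=(1, 2))
example_client_points = [
    tf.data.Dataset.from_tensor_slices(tf.random.normal([8, 2]))
    for _ in range(2)
]
kmeans_state = kmeans_process.initialize()
kmeans_output = kmeans_process.next(kmeans_state, example_client_points)
centroids = kmeans_process.get_model_weights(kmeans_output.state)  # (3, 2)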
Example #10
def build_fed_sgd(
    model_fn: Callable[[], model_lib.Model],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
) -> learning_process.LearningProcess:
    """Builds a learning process that performs federated SGD.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated SGD on client models. The learning process has the following methods
  inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with type signature `( -> S@SERVER)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState`
      representing the initial state of the server.
  *   `next`: A `tff.Computation` with type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)` where `S` is a
      `LearningAlgorithmState` whose type matches that of the output
      of `initialize`, and `{B*}@CLIENTS` represents the client datasets, where
      `B` is the type of a single batch. This computation returns a
      `LearningAlgorithmState` representing the updated server state and the
      metrics during client training and any other metrics from broadcast and
      aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time `next` is called, the server model is broadcast to each client
  using a distributor. Each client sums the gradients for each batch in its
  local dataset (without updating its model) and averages the result based on
  its number of examples. These average gradients are then aggregated at the
  server, and applied to the server model using a
  `tf.keras.optimizers.Optimizer`.

  This implements the original FedSGD algorithm in [McMahan et al.,
  2017](https://arxiv.org/abs/1602.05629).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. The optimizer is used to
      apply client updates to the server model.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    py_typecheck.check_callable(model_fn)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize
    client_work = _build_fed_sgd_client_work(
        model_fn,
        metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)
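# A hedged usage sketch for `build_fed_sgd`, reusing the illustrative
# `example_model_fn` and `example_client_data` from the Mime Lite sketch
# above. No client optimizer is needed, since clients only compute gradients.
fed_sgd_process = build_fed_sgd(example_model_fn)
fed_sgd_state = fed_sgd_process.initialize()
fed_sgd_output = fed_sgd_process.next(fed_sgd_state, example_client_data)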
Example #11
def build_weighted_fed_prox(
    model_fn: Callable[[], model_lib.Model],
    proximal_strength: float,
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs the FedProx algorithm.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  example-weighted FedProx on client models. This algorithm behaves the same as
  federated averaging, except that it uses a proximal regularization term that
  encourages clients to not drift too far from the server model.

  The iterative process has the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `client_optimizer_fn`. Each client computes the
  difference between the client model after training and the initial model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function, according to `client_weighting`. The aggregate model
  delta is applied at the server using a server optimizer, as in the FedOpt
  framework proposed in [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

  Note: The default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedProx algorithm in
  [Li et al., 2020](https://arxiv.org/abs/1812.06127). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    proximal_strength: A nonnegative float representing the parameter of
      FedProx's regularization term. When set to `0`, the algorithm reduces to
      FedAvg. Higher values prevent clients from moving too far from the server
      model during local training.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that broadcasts the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.

  Raises:
    ValueError: If `proximal_strength` is not a nonnegative float.
  """
    if not isinstance(proximal_strength, float) or proximal_strength < 0.0:
        raise ValueError(
            'proximal_strength must be a nonnegative float, found {}'.format(
                proximal_strength))

    py_typecheck.check_callable(model_fn)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    process_signature = aggregator.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a '
            'compatible `AggregationProcess`. The processes must '
            'retain the type structure of the inputs on the '
            f'server, but got {input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize
    client_work = model_delta_client_work.build_model_delta_client_work(
        model_fn=model_fn,
        optimizer=client_optimizer_fn,
        client_weighting=client_weighting,
        delta_l2_regularizer=proximal_strength,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)
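# A hedged usage sketch for `build_weighted_fed_prox`, reusing the
# illustrative `example_model_fn` and `example_client_data` from the Mime Lite
# sketch above. `proximal_strength=0.0` would recover plain FedAvg; a small
# positive value keeps client models near the server model.
import tensorflow as tf

fed_prox_process = build_weighted_fed_prox(
    example_model_fn,
    proximal_strength=0.1,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
fed_prox_state = fed_prox_process.initialize()
fed_prox_output = fed_prox_process.next(fed_prox_state, example_client_data)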
Example #12
def build_weighted_fed_avg(
    model_fn: Callable[[], model_lib.Model],
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
    model: Optional[functional.FunctionalModel] = None,
) -> learning_process.LearningProcess:
    """Builds a learning process that performs federated averaging.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  federated averaging on client models. The iterative process has the following
  methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `client_optimizer_fn`. Each client computes the
  difference between the client model after training and its initial model.
  These model deltas are then aggregated at the server using a weighted
  aggregation function, according to `client_weighting`. The aggregate model
  delta is applied at the server using a server optimizer.

  Note: the default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedAvg algorithm in
  [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers (this generalized FedAvg algorithm is described in
  [Reddi et al., 2021](https://arxiv.org/abs/2003.00295)).

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
      Cannot be used with the `model` argument.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via
      `tff.learning.templates.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.
    model: A `tff.learning.models.FunctionalModel` to train. Cannot be used with
      `model_fn` and must be `None` if `model_fn` is not `None`.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    if model is not None and model_fn is not None:
        raise ValueError(
            'Must specify only one of `model` and `model_fn`; both '
            'were not `None`.')
    elif model is None and model_fn is None:
        raise ValueError(
            'Must specify one of `model` and `model_fn`; both were '
            '`None`.')
    if model_fn is not None:
        py_typecheck.check_callable(model_fn)
    else:
        py_typecheck.check_type(model, functional.FunctionalModel)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    if model is not None:

        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            return model_utils.ModelWeights(
                tuple(
                    tf.convert_to_tensor(w) for w in model.initial_weights[0]),
                tuple(
                    tf.convert_to_tensor(w) for w in model.initial_weights[1]))
    else:

        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)

    if model is not None:
        trainable_weights_type, _ = model_weights_type
        model_update_type = trainable_weights_type
    else:
        model_update_type = model_weights_type.trainable
    aggregator = model_aggregator.create(
        model_update_type, computation_types.TensorType(tf.float32))

    process_signature = aggregator.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a compatible '
            '`AggregationProcess`. The processes must retain the type '
            'structure of the inputs on the server, but got '
            f'{input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize

    if model is not None:
        client_work = model_delta_client_work.build_functional_model_delta_client_work(
            model=model,
            optimizer=client_optimizer_fn,
            client_weighting=client_weighting,
            metrics_aggregator=metrics_aggregator)
    else:
        client_work = model_delta_client_work.build_model_delta_client_work(
            model_fn=model_fn,
            optimizer=client_optimizer_fn,
            client_weighting=client_weighting,
            metrics_aggregator=metrics_aggregator,
            use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)
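# A hedged sketch of the canonical training loop for `build_weighted_fed_avg`.
# With the default server optimizer (SGD with learning rate 1.0) this is the
# original FedAvg of McMahan et al., 2017. `example_model_fn` and
# `example_client_data` are the illustrative stand-ins from the Mime Lite
# sketch above.
import tensorflow as tf

fed_avg_process = build_weighted_fed_avg(
    example_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
fed_avg_state = fed_avg_process.initialize()
for _ in range(5):
    fed_avg_output = fed_avg_process.next(fed_avg_state, example_client_data)
    fed_avg_state = fed_avg_output.state
final_weights = fed_avg_process.get_model_weights(fed_avg_state)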