Exemplo n.º 1
0
 def test_created_keras_optimizer_raises(self):
     # An already-constructed Keras optimizer instance (rather than a no-arg
     # callable) is expected to be rejected with a TypeError.
     weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
     with self.assertRaises(TypeError):
         model_delta_client_work.build_model_delta_client_work(
             model_fn=model_examples.LinearRegression,
             optimizer=tf.keras.optimizers.SGD(1.0),
             client_weighting=weighting)
Exemplo n.º 2
0
 def test_created_model_raises(self):
     # Passing a model instance instead of a no-arg model constructor is
     # expected to be rejected with a TypeError.
     weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
     with self.assertRaises(TypeError):
         model_delta_client_work.build_model_delta_client_work(
             model_fn=model_examples.LinearRegression(),
             optimizer=sgdm.build_sgdm(1.0),
             client_weighting=weighting)
Exemplo n.º 3
0
    def test_delta_regularizer_yields_smaller_model_delta(self, optimizer):
        """Checks that a positive `delta_l2_regularizer` shrinks the update.

        Both processes are built from the same parameterized `optimizer`. The
        regularizer strength is therefore the only difference between them.
        The regularized update must have a strictly smaller global norm, and
        both processes must report the same number of training examples.
        """
        weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
        simple_process = model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=weighting,
            delta_l2_regularizer=0.0)
        # Reuse the parameterized `optimizer` (not a hard-coded SGD) so the
        # comparison isolates the effect of the regularizer.
        proximal_process = model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=weighting,
            delta_l2_regularizer=1.0)
        client_data = [create_test_dataset()]
        client_model_weights = [create_test_initial_weights()]

        simple_output = simple_process.next(simple_process.initialize(),
                                            client_model_weights, client_data)
        proximal_output = proximal_process.next(proximal_process.initialize(),
                                                client_model_weights,
                                                client_data)

        simple_update_norm = tf.linalg.global_norm(
            tf.nest.flatten(simple_output.result[0].update))
        proximal_update_norm = tf.linalg.global_norm(
            tf.nest.flatten(proximal_output.result[0].update))
        self.assertGreater(simple_update_norm, proximal_update_norm)

        self.assertEqual(simple_output.measurements['train']['num_examples'],
                         proximal_output.measurements['train']['num_examples'])
Exemplo n.º 4
0
 def test_negative_proximal_strength_raises(self):
     # A negative `delta_l2_regularizer` is expected to raise a ValueError.
     weighting = client_weight_lib.ClientWeighting.NUM_EXAMPLES
     with self.assertRaises(ValueError):
         model_delta_client_work.build_model_delta_client_work(
             model_fn=model_examples.LinearRegression,
             optimizer=sgdm.build_sgdm(1.0),
             client_weighting=weighting,
             delta_l2_regularizer=-1.0)
Exemplo n.º 5
0
    def test_custom_metrics_aggregator(self):

        def sum_then_finalize_then_times_two(metric_finalizers,
                                             local_unfinalized_metrics_type):
            # Custom metrics aggregator: sum the clients' unfinalized metrics,
            # then finalize each metric and multiply the result by two.
            clients_metrics_type = computation_types.at_clients(
                local_unfinalized_metrics_type)

            @federated_computation.federated_computation(clients_metrics_type)
            def aggregation_computation(client_local_unfinalized_metrics):
                summed_metrics = intrinsics.federated_sum(
                    client_local_unfinalized_metrics)

                @tensorflow_computation.tf_computation(
                    local_unfinalized_metrics_type)
                def finalizer_computation(unfinalized_metrics):
                    return collections.OrderedDict(
                        (name, finalizer(unfinalized_metrics[name]) * 2)
                        for name, finalizer in metric_finalizers.items())

                return intrinsics.federated_map(finalizer_computation,
                                                summed_metrics)

            return aggregation_computation

        process = model_delta_client_work.build_model_delta_client_work(
            model_fn=self.create_model,
            optimizer=sgdm.build_sgdm(1.0),
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
            metrics_aggregator=sum_then_finalize_then_times_two)
        initial_weights = [create_test_initial_weights()]
        datasets = [create_test_dataset()]
        output = process.next(process.initialize(), initial_weights, datasets)
        # The custom aggregator doubles every finalized train metric.
        self.assertEqual(output.measurements['train']['num_examples'], 16)
Exemplo n.º 6
0
    def test_execution_with_optimizer(self, optimizer):
        # Runs one round of client work with the parameterized optimizer and
        # checks that only the 'train' measurements group is reported.
        process = model_delta_client_work.build_model_delta_client_work(
            self.create_model,
            optimizer,
            client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
        datasets = [create_test_dataset()]
        initial_weights = [create_test_initial_weights()]

        output = process.next(process.initialize(), initial_weights, datasets)

        self.assertCountEqual(output.measurements.keys(), ['train'])
Exemplo n.º 7
0
    def test_type_properties(self, optimizer, weighting):
        model_fn = model_examples.LinearRegression
        process = model_delta_client_work.build_model_delta_client_work(
            model_fn, optimizer, weighting)
        self.assertIsInstance(process, client_works.ClientWorkProcess)

        # Trainable weights: a (2, 1) float32 tensor and a float32 scalar;
        # non-trainable weights: a single float32 scalar.
        weights_type = model_utils.ModelWeights(
            trainable=computation_types.to_type([(tf.float32, (2, 1)),
                                                 tf.float32]),
            non_trainable=computation_types.to_type([tf.float32]))
        clients_weights_type = computation_types.at_clients(weights_type)
        input_element_type = computation_types.to_type(model_fn().input_spec)
        clients_data_type = computation_types.at_clients(
            computation_types.SequenceType(input_element_type))
        client_result = client_works.ClientResult(
            update=weights_type.trainable,
            update_weight=computation_types.TensorType(tf.float32))
        clients_result_type = computation_types.at_clients(client_result)
        server_state_type = computation_types.at_server(())
        train_metrics = collections.OrderedDict(
            loss=tf.float32, num_examples=tf.int32)
        server_measurements_type = computation_types.at_server(
            collections.OrderedDict(train=train_metrics))

        # `initialize` takes no argument and returns the (empty) server state.
        computation_types.FunctionType(
            parameter=None,
            result=server_state_type).check_equivalent_to(
                process.initialize.type_signature)

        # `next` maps (state, weights, data) to a MeasuredProcessOutput.
        next_parameter = collections.OrderedDict(
            state=server_state_type,
            weights=clients_weights_type,
            client_data=clients_data_type)
        next_result = measured_process.MeasuredProcessOutput(
            server_state_type, clients_result_type, server_measurements_type)
        computation_types.FunctionType(
            parameter=next_parameter,
            result=next_result).check_equivalent_to(
                process.next.type_signature)
Exemplo n.º 8
0
def build_basic_fedavg_process(model_fn: Callable[[], model_lib.Model],
                               client_learning_rate: float):
    """Builds vanilla Federated Averaging process.

    The created process is the basic form of the Federated Averaging algorithm
    as proposed by http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf in
    Algorithm 1, for training the model created by `model_fn`. The following is
    the algorithm in pseudo-code:

    ```
    # Inputs: m: Initial model weights; eta: Client learning rate
    for _ in range(num_rounds):
      for c in available_clients_indices:
        delta_m_c, w_c = client_update(m, eta)
      aggregate_model_delta = sum_c(delta_m_c * w_c) / sum_c(w_c)
      m = m - aggregate_model_delta
    return m  # Final trained model.

    def client_update(m, eta):
      initial_m = m
      for batch in client_dataset:
        m = m - eta * grad(m, batch)
      delta_m = initial_m - m
      return delta_m, size(client_dataset)
    ```

    The other algorithm hyper parameters (batch size, number of local epochs)
    are controlled by the data provided to the built process.

    An example usage of the returned `LearningProcess` in simulation:

    ```
    fedavg = build_basic_fedavg_process(model_fn, 0.1)

    # Create a `LearningAlgorithmState` containing the initial model weights for
    # the model returned from `model_fn`.
    state = fedavg.initialize()
    for _ in range(num_rounds):
      client_data = ...  # Preprocessed client datasets
      output = fedavg.next(state, client_data)
      write_round_metrics(output.metrics)
      # The new state contains the updated model weights after this round.
      state = output.state
    ```

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`.
      client_learning_rate: A float. Learning rate for the SGD at clients. This
        is checked with `py_typecheck.check_type(..., float)`, so an `int` such
        as `1` is rejected; pass `1.0` instead.

    Returns:
      A `LearningProcess`.
    """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_learning_rate, float)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    distributor = distributors.build_broadcast_process(model_weights_type)
    # Client updates are weighted by number of examples, i.e. the
    # `w_c = size(client_dataset)` of the pseudo-code above.
    client_work = model_delta_client_work.build_model_delta_client_work(
        model_fn,
        sgdm.build_sgdm(client_learning_rate),
        client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
    aggregator = mean.MeanFactory().create(
        client_work.next.type_signature.result.result.member.update,
        client_work.next.type_signature.result.result.member.update_weight)
    # Server-side SGD with learning rate 1.0 applies
    # `m = m - aggregate_model_delta`.
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), model_weights_type)

    return compose_learning_process(initial_model_weights_fn, distributor,
                                    client_work, aggregator, finalizer)
Exemplo n.º 9
0
def build_weighted_fed_prox(
    model_fn: Callable[[], model_lib.Model],
    proximal_strength: float,
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs the FedProx algorithm.

    This function creates a `tff.learning.templates.LearningProcess` that
    performs example-weighted FedProx on client models. This algorithm behaves
    the same as federated averaging, except that it uses a proximal
    regularization term that encourages clients to not drift too far from the
    server model.

    The iterative process has the following methods inherited from
    `tff.learning.templates.LearningProcess`:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` representing the initial
        state of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `{B*}@CLIENTS` represents the client
        datasets. The output `L` contains the updated server state, as well as
        aggregated metrics at the server, including client training metrics and
        any other metrics from distribution and aggregation processes.
    *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
        where `S` is a `tff.learning.templates.LearningAlgorithmState` whose
        type matches the output of `initialize` and `next`, and `M` represents
        the type of the model weights used during training.
    *   `set_model_weights`: A `tff.Computation` with type signature
        `(<S, M> -> S)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `M` represents the type of the model weights
        used during training.

    Each time the `next` method is called, the server model is communicated to
    each client using the provided `model_distributor`. For each client, local
    training is performed using `client_optimizer_fn`. Each client computes the
    difference between the client model after training and the initial model.
    These model deltas are then aggregated at the server using a weighted
    aggregation function, according to `client_weighting`. The aggregate model
    delta is applied at the server using a server optimizer, as in the FedOpt
    framework proposed in [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

    Note: The default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 1.0, which corresponds to adding the model delta to
    the current server model. This recovers the original FedProx algorithm in
    [Li et al., 2020](https://arxiv.org/abs/1812.06127). More
    sophisticated federated averaging procedures may use different learning
    rates or server optimizers.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      proximal_strength: A nonnegative float representing the parameter of
        FedProx's regularization term. When set to `0.0`, the algorithm reduces
        to FedAvg. Higher values prevent clients from moving too far from the
        server model during local training. Must be a Python `float`: an `int`
        such as `0` is rejected.
      client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`. By default, this uses
        `tf.keras.optimizers.SGD` with a learning rate of 1.0.
      client_weighting: A member of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method. By default, weighting by number
        of examples is used.
      model_distributor: An optional `DistributionProcess` that broadcasts the
        model weights on the server to the clients. If set to `None`, the
        distributor is constructed via
        `tff.learning.templates.build_broadcast_process`.
      model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
        used to aggregate client updates on the server. If `None`, this is set
        to `tff.aggregators.MeanFactory`.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It is
        currently necessary to set this flag to True for performant GPU
        simulations.

    Returns:
      A `tff.learning.templates.LearningProcess`.

    Raises:
      ValueError: If `proximal_strength` is not a nonnegative float (NaN is
        rejected as well).
      TypeError: If `model_aggregator` is not a
        `tff.aggregators.WeightedAggregationFactory`, or if the aggregation
        process it creates does not preserve the type of the client values.
    """
    # `not x >= 0.0` (rather than `x < 0.0`) also rejects NaN, for which every
    # ordered comparison evaluates to False.
    if (not isinstance(proximal_strength, float) or
            not proximal_strength >= 0.0):
        raise ValueError(
            'proximal_strength must be a nonnegative float, found {}'.format(
                proximal_strength))

    py_typecheck.check_callable(model_fn)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    # The aggregated server value must have the same structure as the client
    # model updates so the finalizer can apply it to the model weights.
    process_signature = aggregator.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a compatible '
            '`AggregationProcess`. The processes must retain the type '
            'structure of the inputs on the server, but got '
            f'{input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize
    # The FedProx proximal term is implemented as an L2 penalty on the
    # client's model delta.
    client_work = model_delta_client_work.build_model_delta_client_work(
        model_fn=model_fn,
        optimizer=client_optimizer_fn,
        client_weighting=client_weighting,
        delta_l2_regularizer=proximal_strength,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)
Exemplo n.º 10
0
def build_weighted_fed_avg(
    model_fn: Optional[Callable[[], model_lib.Model]],
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    *,
    client_weighting: Optional[
        client_weight_lib.ClientWeighting] = client_weight_lib.ClientWeighting.
    NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False,
    model: Optional[functional.FunctionalModel] = None,
) -> learning_process.LearningProcess:
    """Builds a learning process that performs federated averaging.

    This function creates a `tff.learning.templates.LearningProcess` that
    performs federated averaging on client models. The iterative process has
    the following methods inherited from
    `tff.learning.templates.LearningProcess`:

    *   `initialize`: A `tff.Computation` with the functional type signature
        `( -> S@SERVER)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` representing the initial
        state of the server.
    *   `next`: A `tff.Computation` with the functional type signature
        `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `{B*}@CLIENTS` represents the client
        datasets. The output `L` contains the updated server state, as well as
        aggregated metrics at the server, including client training metrics and
        any other metrics from distribution and aggregation processes.
    *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
        where `S` is a `tff.learning.templates.LearningAlgorithmState` whose
        type matches the output of `initialize` and `next`, and `M` represents
        the type of the model weights used during training.
    *   `set_model_weights`: A `tff.Computation` with type signature
        `(<S, M> -> S)`, where `S` is a
        `tff.learning.templates.LearningAlgorithmState` whose type matches the
        output of `initialize` and `M` represents the type of the model weights
        used during training.

    Each time the `next` method is called, the server model is communicated to
    each client using the provided `model_distributor`. For each client, local
    training is performed using `client_optimizer_fn`. Each client computes the
    difference between the client model after training and its initial model.
    These model deltas are then aggregated at the server using a weighted
    aggregation function, according to `client_weighting`. The aggregate model
    delta is applied at the server using a server optimizer.

    Note: the default server optimizer function is `tf.keras.optimizers.SGD`
    with a learning rate of 1.0, which corresponds to adding the model delta to
    the current server model. This recovers the original FedAvg algorithm in
    [McMahan et al., 2017](https://arxiv.org/abs/1602.05629). More
    sophisticated federated averaging procedures may use different learning
    rates or server optimizers (this generalized FedAvg algorithm is described
    in [Reddi et al., 2021](https://arxiv.org/abs/2003.00295)).

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`, or
        `None` when `model` is given. This method must *not* capture TensorFlow
        tensors or variables and use them. The model must be constructed
        entirely from scratch on each invocation, returning the same
        pre-constructed model each call will result in an error. Cannot be used
        with `model` argument.
      client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`.
      server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
        callable that returns a `tf.keras.Optimizer`. By default, this uses
        `tf.keras.optimizers.SGD` with a learning rate of 1.0.
      client_weighting: A member of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method. By default, weighting by number
        of examples is used.
      model_distributor: An optional `DistributionProcess` that distributes the
        model weights on the server to the clients. If set to `None`, the
        distributor is constructed via
        `tff.learning.templates.build_broadcast_process`.
      model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
        used to aggregate client updates on the server. If `None`, this is set
        to `tff.aggregators.MeanFactory`.
      metrics_aggregator: A function that takes in the metric finalizers (i.e.,
        `tff.learning.Model.metric_finalizers()`) and a
        `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the
        TFF type of `tff.learning.Model.report_local_unfinalized_metrics()`),
        and returns a `tff.Computation` for aggregating the unfinalized metrics.
        If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
      use_experimental_simulation_loop: Controls the reduce loop function for
        input dataset. An experimental reduce loop is used for simulation. It is
        currently necessary to set this flag to True for performant GPU
        simulations. Only used with `model_fn`: it is not forwarded to the
        client work (and so has no effect) when `model` is given.
      model: A `tff.learning.models.FunctionalModel` to train. Cannot be used
        with `model_fn` and must be `None` if `model_fn` is not `None`.

    Returns:
      A `tff.learning.templates.LearningProcess`.

    Raises:
      ValueError: If both or neither of `model` and `model_fn` are provided.
      TypeError: If `model_fn` is not callable, `model` is not a
        `tff.learning.models.FunctionalModel`, `client_weighting` is not a
        `tff.learning.ClientWeighting`, `model_aggregator` is not a
        `tff.aggregators.WeightedAggregationFactory`, or the aggregation process
        created by `model_aggregator` does not preserve the type of the client
        values.
    """
    # Exactly one of `model` / `model_fn` must be supplied.
    if model is not None and model_fn is not None:
        raise ValueError(
            'Must specify only one of `model` and `model_fn`, both '
            'were not `None`.')
    elif model is None and model_fn is None:
        raise ValueError(
            'Must specify one of `model` and `model_fn`, both were '
            '`None`.')
    if model_fn is not None:
        py_typecheck.check_callable(model_fn)
    else:
        py_typecheck.check_type(model, functional.FunctionalModel)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    if model is not None:

        # Functional models expose their initial weights directly as a
        # (trainable, non_trainable) pair of sequences.
        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            return model_utils.ModelWeights(
                tuple(
                    tf.convert_to_tensor(w) for w in model.initial_weights[0]),
                tuple(
                    tf.convert_to_tensor(w) for w in model.initial_weights[1]))
    else:

        @tensorflow_computation.tf_computation()
        def initial_model_weights_fn():
            return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)

    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)

    # Only the trainable weights are updated and aggregated.
    if model is not None:
        trainable_weights_type, _ = model_weights_type
        model_update_type = trainable_weights_type
    else:
        model_update_type = model_weights_type.trainable
    aggregator = model_aggregator.create(
        model_update_type, computation_types.TensorType(tf.float32))

    # The aggregated server value must have the same structure as the client
    # model updates so the finalizer can apply it to the model weights.
    process_signature = aggregator.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_aggregator` does not produce a compatible '
            '`AggregationProcess`. The processes must retain the type '
            'structure of the inputs on the server, but got '
            f'{input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    if metrics_aggregator is None:
        metrics_aggregator = metric_aggregator.sum_then_finalize

    if model is not None:
        # NOTE: `use_experimental_simulation_loop` is not passed on this path.
        client_work = model_delta_client_work.build_functional_model_delta_client_work(
            model=model,
            optimizer=client_optimizer_fn,
            client_weighting=client_weighting,
            metrics_aggregator=metrics_aggregator)
    else:
        client_work = model_delta_client_work.build_model_delta_client_work(
            model_fn=model_fn,
            optimizer=client_optimizer_fn,
            client_weighting=client_weighting,
            metrics_aggregator=metrics_aggregator,
            use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer_fn, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              aggregator, finalizer)