Example #1
def robust_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
) -> factory.AggregationFactory:
    """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips in the L2 norm to a moderately high norm for robustness
  to outliers.

  For details on clipping and zeroing see `tff.aggregators.clipping_factory`
  and `tff.aggregators.zeroing_factory`. For details on the quantile-based
  adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
    weighted: Whether the mean is weighted (vs. unweighted).

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = mean.MeanFactory() if weighted else mean.UnweightedMeanFactory()

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
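
A minimal usage sketch, not from the original source, showing where such a factory is typically consumed. `my_model_fn` is a hypothetical no-arg model constructor, and the wiring assumes the public `tff.learning.algorithms.build_weighted_fed_avg` API:

import tensorflow_federated as tff

# Hypothetical usage sketch: plug the robust mean into federated averaging.
# `my_model_fn` is a placeholder for a no-arg function returning a
# `tff.learning.Model`; it is not defined in the original source.
aggregator = robust_aggregator(zeroing=True, clipping=True, weighted=True)
learning_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=my_model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1),
    model_aggregator=aggregator)
state = learning_process.initialize()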
Example #2
    def test_add_measurements_to_unweighted_aggregation_factory_output(self):
        mean_factory = mean.UnweightedMeanFactory()
        debug_mean_factory = debug_measurements.add_debug_measurements(
            mean_factory)
        value_type = computation_types.TensorType(tf.float32)
        mean_aggregator = mean_factory.create(value_type)
        debug_aggregator = debug_mean_factory.create(value_type)

        state = mean_aggregator.initialize()
        mean_output = mean_aggregator.next(state, [2.0, 4.0])
        debug_output = debug_aggregator.next(state, [2.0, 4.0])
        self.assertEqual(mean_output.state, debug_output.state)
        self.assertNear(mean_output.result, debug_output.result, err=1e-6)

        mean_measurements = mean_output.measurements
        expected_debugging_measurements = {
            'average_client_norm': 3.0,
            'std_dev_client_norm': tf.math.sqrt(2.0),
            'server_update_max': 3.0,
            'server_update_norm': 3.0,
            'server_update_min': 3.0,
        }
        debugging_measurements = debug_output.measurements
        self.assertCountEqual(
            list(debugging_measurements.keys()),
            list(mean_measurements.keys()) +
            list(expected_debugging_measurements.keys()))
        for k in mean_output.measurements:
            self.assertEqual(mean_measurements[k], debugging_measurements[k])
        for k in expected_debugging_measurements:
            self.assertNear(debugging_measurements[k],
                            expected_debugging_measurements[k],
                            err=1e-6)
Example #3
  def test_type_properties_unweighted(self, value_type):
    value_type = computation_types.to_type(value_type)

    factory_ = mean.UnweightedMeanFactory()
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    process = factory_.create(value_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    param_value_type = computation_types.at_clients(value_type)
    result_value_type = computation_types.at_server(value_type)

    expected_state_type = computation_types.at_server(())
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=()))

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type, value=param_value_type),
        result=measured_process.MeasuredProcessOutput(
            expected_state_type, result_value_type, expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))
Example #4
  def test_type_properties(self, value_type):
    factory = stochastic_discretization.StochasticDiscretizationFactory(
        step_size=0.1,
        inner_agg_factory=_measurement_aggregator,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    value_type = computation_types.to_type(value_type)
    quantize_type = type_conversions.structure_from_tensor_type_tree(
        lambda x: (tf.int32, x.shape), value_type)
    process = factory.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    server_state_type = computation_types.StructType([('step_size', tf.float32),
                                                      ('inner_agg_process', ())
                                                     ])
    server_state_type = computation_types.at_server(server_state_type)
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                            expected_initialize_type)

    expected_measurements_type = computation_types.StructType([
        ('stochastic_discretization', quantize_type), ('distortion', tf.float32)
    ])
    expected_measurements_type = computation_types.at_server(
        expected_measurements_type)
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=server_state_type,
            value=computation_types.at_clients(value_type)),
        result=measured_process.MeasuredProcessOutput(
            state=server_state_type,
            result=computation_types.at_server(value_type),
            measurements=expected_measurements_type))
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            expected_next_type)
Example #5
 def test_unweighted_full_gradient_aggregator_raises(self):
     with self.assertRaises(TypeError):
         mime._build_mime_lite_client_work(
             model_examples.LinearRegression(),
             sgdm.build_sgdm(1.0),
             client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
             full_gradient_aggregator=mean.UnweightedMeanFactory())
Example #6
def secure_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
) -> factory.AggregationFactory:
    """Creates secure aggregator with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.
  After weighting in the mean, the weighted values are summed using a
  cryptographic protocol that ensures the server cannot see individual updates
  until a sufficient number of updates have been added together. For details,
  see Bonawitz et al. (2017)
  https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized
  using the `tff.federated_secure_sum_bitwidth` operator.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for secure aggregation.
    weighted: Whether the mean is weighted (vs. unweighted).

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    secure_clip_bound = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0,
        secure_estimation=True)

    factory_ = secure.SecureSumFactory(secure_clip_bound)

    if weighted:
        factory_ = mean.MeanFactory(
            value_sum_factory=factory_,
            # Use a power of 2 minus one to more accurately encode floating dtypes
            # that actually contain integer values. 2 ^ 20 gives us approximately a
            # range of [0, 1 million]. Existing use cases have the weights either
            # all ones, or a variant of number of examples processed locally.
            weight_sum_factory=secure.SecureSumFactory(
                upper_bound_threshold=float(2**20 - 1),
                lower_bound_threshold=0.0))
    else:
        factory_ = mean.UnweightedMeanFactory(
            value_sum_factory=factory_,
            count_sum_factory=secure.SecureSumFactory(upper_bound_threshold=1,
                                                      lower_bound_threshold=0))

    if clipping:
        factory_ = _default_clipping(factory_, secure_estimation=True)

    if zeroing:
        factory_ = _default_zeroing(factory_, secure_estimation=True)

    return factory_
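
A short sketch (an illustrative assumption, not source code) of instantiating the weighted secure aggregator's process; the two-argument `create` for weighted factories matches the pattern visible in Example #9 below:

# Usage sketch; reuses `computation_types` and `tf` from the surrounding module.
aggregator = secure_aggregator(weighted=True)
process = aggregator.create(
    computation_types.to_type(tf.float32),  # client value type
    computation_types.to_type(tf.float32))  # client weight type
state = process.initialize()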
Example #7
def compression_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
    debug_measurements_fn: Optional[Callable[
        [factory.AggregationFactory], factory.AggregationFactory]] = None,
    **kwargs,
) -> factory.AggregationFactory:
    """Creates aggregator with compression and adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients and clips in the L2 norm to a moderately high norm for robustness to
  outliers. After weighting in the mean, the weighted values are uniformly
  quantized to reduce the size of the model update communicated from clients to
  the server. For details, see Suresh et al. (2017)
  http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
  configuration is chosen such that compression does not have an adverse effect
  on trained model quality in typical tasks.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for quantization.
    weighted: Whether the mean is weighted (vs. unweighted).
    debug_measurements_fn: A callable to add measurements suitable for debugging
      learning algorithms. Possible values are `None`,
      `tff.learning.add_debug_measurements`, or
      `tff.learning.add_debug_measurements_with_mixed_dtype`.
    **kwargs: Additional keyword arguments forwarded to
      `encoded.EncodedSumFactory.quantize_above_threshold`.

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = encoded.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000, **kwargs)

    factory_ = (mean.MeanFactory(factory_)
                if weighted else mean.UnweightedMeanFactory(factory_))

    if debug_measurements_fn:
        factory_ = debug_measurements_fn(factory_)
        if (weighted and not isinstance(
                factory_, factory.WeightedAggregationFactory)) or (
                    (not weighted) and (not isinstance(
                        factory_, factory.UnweightedAggregationFactory))):
            raise TypeError(
                'debug_measurements_fn should return the same type.')

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
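
To make the `debug_measurements_fn` hook concrete, a hypothetical sketch; `debug_measurements.add_debug_measurements` appears in Example #2 above, though this particular pairing is our assumption:

# Sketch: wrap the compression mean with debugging measurements. The type
# check above raises TypeError if the wrapper changes whether the factory
# is weighted or unweighted.
aggregator = compression_aggregator(
    weighted=False,
    debug_measurements_fn=debug_measurements.add_debug_measurements)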
Example #8
def ddp_secure_aggregator(
        noise_multiplier: float,
        expected_clients_per_round: int,
        bits: int = 20,
        zeroing: bool = True,
        rotation_type: str = 'hd') -> factory.UnweightedAggregationFactory:
    """Creates aggregator with adaptive zeroing and distributed DP.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and performs distributed DP (compression, discrete noising, and
  SecAgg) with adaptive clipping for differentially private learning. For
  details of the two main distributed DP algorithms see
  https://arxiv.org/pdf/2102.06387
  or https://arxiv.org/pdf/2110.04995.pdf. The adaptive clipping uses the
  geometric method described in https://arxiv.org/abs/1905.03871.

  Args:
    noise_multiplier: A float specifying the noise multiplier (with respect to
      the initial L2 clipping) for the distributed DP mechanism for model
      updates. A value of 1.0 or higher may be needed for meaningful privacy.
    expected_clients_per_round: An integer specifying the expected number of
      clients per round. Must be positive.
    bits: An integer specifying the bit-width for the aggregation. Note that
      this is for the noisy, quantized aggregate at the server and thus should
      account for the `expected_clients_per_round`. Must be in the inclusive
      range of [1, 22]. This is set to 20 bits by default, and it dictates the
      computational and communication efficiency of Secure Aggregation. Setting
      it to less than 20 bits should work fine for most cases. For instance, for
      an expected number of securely aggregated client updates of 100, 12 bits
      should be enough, and for an expected number of securely aggregated client
      updates of 1000, 16 bits should be enough.
    zeroing: A bool indicating whether to enable adaptive zeroing for data
      corruption mitigation. Defaults to `True`.
    rotation_type: A string indicating what rotation to use for distributed DP.
      Valid options are 'hd' (Hadamard transform) and 'dft' (discrete Fourier
      transform). Defaults to `hd`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.
  """
    agg_factory = distributed_dp.DistributedDpSumFactory(
        noise_multiplier=noise_multiplier,
        expected_clients_per_round=expected_clients_per_round,
        bits=bits,
        l2_clip=0.1,
        mechanism='distributed_skellam',
        rotation_type=rotation_type,
        auto_l2_clip=True)
    agg_factory = mean.UnweightedMeanFactory(
        value_sum_factory=agg_factory,
        count_sum_factory=secure.SecureSumFactory(upper_bound_threshold=1,
                                                  lower_bound_threshold=0))

    if zeroing:
        agg_factory = _default_zeroing(agg_factory, secure_estimation=True)

    return agg_factory
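
A hedged construction sketch with illustrative parameter values, following the bit-width guidance in the docstring:

# Sketch: for ~100 securely aggregated clients per round the docstring
# suggests 12 bits suffice. A small vector type is used because the rotation
# preprocessing operates on flattened tensors; exact type requirements may vary.
aggregator = ddp_secure_aggregator(
    noise_multiplier=1.0,
    expected_clients_per_round=100,
    bits=12,
    rotation_type='hd')
process = aggregator.create(computation_types.TensorType(tf.float32, [16]))
state = process.initialize()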
Example #9
  def test_incorrect_create_type_raises(self, wrong_type):
    factory_ = mean.MeanFactory()
    correct_type = computation_types.to_type(tf.float32)
    with self.assertRaises(TypeError):
      factory_.create(wrong_type, correct_type)
    with self.assertRaises(TypeError):
      factory_.create(correct_type, wrong_type)

    factory_ = mean.UnweightedMeanFactory()
    with self.assertRaises(TypeError):
      factory_.create(wrong_type)
Example #10
 def test_add_measurements_to_unweighted_aggregation_factory_types(self):
     mean_factory = mean.UnweightedMeanFactory()
     debug_mean_factory = debug_measurements.add_debug_measurements(
         mean_factory)
     value_type = computation_types.TensorType(tf.float32)
     mean_aggregator = mean_factory.create(value_type)
     debug_aggregator = debug_mean_factory.create(value_type)
     self.assertFalse(debug_aggregator.is_weighted)
     self.assertEqual(mean_aggregator.initialize.type_signature,
                      debug_aggregator.initialize.type_signature)
     self.assertEqual(mean_aggregator.next.type_signature.parameter,
                      debug_aggregator.next.type_signature.parameter)
     self.assertEqual(mean_aggregator.next.type_signature.result.state,
                      debug_aggregator.next.type_signature.result.state)
     self.assertEqual(mean_aggregator.next.type_signature.result.result,
                      debug_aggregator.next.type_signature.result.result)
Example #11
  def test_structure_value_unweighted(self):
    factory_ = mean.UnweightedMeanFactory()
    value_type = computation_types.to_type(_test_struct_type)
    process = factory_.create(value_type)
    expected_state = ()
    expected_measurements = collections.OrderedDict(mean_value=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    output = process.next(state, client_data)

    self.assertAllEqual(expected_state, output.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), output.result)
    self.assertEqual(expected_measurements, output.measurements)
Example #12
  def test_inner_value_sum_factory_unweighted(self):
    sum_factory = aggregators_test_utils.SumPlusOneFactory()
    factory_ = mean.UnweightedMeanFactory(value_sum_factory=sum_factory)
    value_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type)

    state = process.initialize()
    self.assertAllEqual(0, state)

    # Values will be summed to 7.0.
    client_data = [1.0, 2.0, 3.0]

    output = process.next(state, client_data)
    self.assertAllEqual(1, output.state)
    self.assertAllClose(7 / 3, output.result)
    self.assertEqual(
        collections.OrderedDict(mean_value=M_CONST), output.measurements)
Example #13
  def test_scalar_value_unweighted(self):
    factory_ = mean.UnweightedMeanFactory()
    value_type = computation_types.to_type(tf.float32)

    process = factory_.create(value_type)
    expected_state = ()
    expected_measurements = collections.OrderedDict(mean_value=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [1.0, 2.0, 3.0]
    output = process.next(state, client_data)
    self.assertAllClose(2.0, output.result)

    self.assertAllEqual(expected_state, output.state)
    self.assertEqual(expected_measurements, output.measurements)
Example #14
def secure_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
) -> factory.AggregationFactory:
    """Creates secure aggregator with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.
  After weighting in the mean, the weighted values are summed using a
  cryptographic protocol that ensures the server cannot see individual updates
  until a sufficient number of updates have been added together. For details,
  see Bonawitz et al. (2017)
  https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized
  using the `tff.federated_secure_sum` operator.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for secure aggregation.
    weighted: Whether the mean is weighted (vs. unweighted).

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    secure_clip_bound = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0)

    factory_ = secure.SecureSumFactory(secure_clip_bound)

    factory_ = (mean.MeanFactory(factory_)
                if weighted else mean.UnweightedMeanFactory(factory_))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
Example #15
def robust_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
    debug_measurements_fn: Optional[Callable[
        [factory.AggregationFactory], factory.AggregationFactory]] = None,
) -> factory.AggregationFactory:
    """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips in the L2 norm to a moderately high norm for robustness
  to outliers.

  For details on clipping and zeroing see `tff.aggregators.clipping_factory`
  and `tff.aggregators.zeroing_factory`. For details on the quantile-based
  adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
    weighted: Whether the mean is weighted (vs. unweighted).
    debug_measurements_fn: A callable to add measurements suitable for debugging
      learning algorithms. Often useful values include `None`,
      `tff.learning.add_debug_measurements`, or
      `tff.learning.add_debug_measurements_with_mixed_dtype`.

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = mean.MeanFactory() if weighted else mean.UnweightedMeanFactory()

    if debug_measurements_fn:
        factory_ = debug_measurements_fn(factory_)

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
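
A brief sketch (our illustration) of the two interfaces this variant can return, depending on `weighted`:

# Sketch: the weighted factory's `create` takes (value_type, weight_type);
# the unweighted one takes only a value type.
weighted_agg = robust_aggregator(
    debug_measurements_fn=debug_measurements.add_debug_measurements)
unweighted_agg = robust_aggregator(weighted=False)
process = unweighted_agg.create(computation_types.to_type(tf.float32))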
Example #16
  def test_discretize_impl(self, value_type, client_values, expected_sum):
    factory = stochastic_discretization.StochasticDiscretizationFactory(
        inner_agg_factory=_measurement_aggregator,
        step_size=0.1,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    value_type = computation_types.to_type(value_type)
    process = factory.create(value_type)
    state = process.initialize()

    expected_result = expected_sum
    expected_quantized_result = tf.nest.map_structure(lambda x: x * 10,
                                                      expected_sum)
    expected_measurements = collections.OrderedDict(
        stochastic_discretization=expected_quantized_result, distortion=0.)

    for _ in range(3):
      output = process.next(state, client_values)
      output_measurements = output.measurements
      self.assertAllClose(output_measurements, expected_measurements)
      result = output.result
      self.assertAllClose(result, expected_result)
Example #17
def compression_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
) -> factory.AggregationFactory:
    """Creates aggregator with compression and adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients and clips in the L2 norm to a moderately high norm for robustness to
  outliers. After weighting in the mean, the weighted values are uniformly
  quantized to reduce the size of the model update communicated from clients to
  the server. For details, see Suresh et al. (2017)
  http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
  configuration is chosen such that compression does not have an adverse effect
  on trained model quality in typical tasks.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for quantization.
    weighted: Whether the mean is weighted (vs. unweighted).

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = encoded.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000)

    factory_ = (mean.MeanFactory(factory_)
                if weighted else mean.UnweightedMeanFactory(factory_))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
Example #18
def build_unweighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.UnweightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.UnweightedAggregationFactory] = None,
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation] = metric_aggregator.sum_then_finalize,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs Mime Lite.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  Mime Lite algorithm on client models. The iterative process has the following
  methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `base_optimizer`, whose state is communicated by
  the server and kept intact during local training. The state is updated only
  at the server based on the full gradient evaluated by the clients based on the
  current server model state. The client full gradients are aggregated by
  unweighted `full_gradient_aggregator`. Each client computes the difference
  between the client model after training and its initial model. These model
  deltas are then aggregated by unweighted `model_aggregator`. The aggregate
  model delta is added to the existing server model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. This is due to the Mime Lite
  algorithm applying the optimizer without changing its state at clients
  (the optimizer's `tf.Variable`s in the case of Keras), which is not possible
  with Keras optimizers without reaching into private implementation details
  and incurring additional computation and memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.UnweightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.UnweightedMeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.UnweightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.UnweightedMeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    if model_aggregator is None:
        model_aggregator = mean.UnweightedMeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.UnweightedAggregationFactory)
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.UnweightedMeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.UnweightedAggregationFactory)

    return build_weighted_mime_lite(
        model_fn=model_fn,
        base_optimizer=base_optimizer,
        server_optimizer=server_optimizer,
        client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
        model_distributor=model_distributor,
        model_aggregator=factory_utils.as_weighted_aggregator(
            model_aggregator),
        full_gradient_aggregator=factory_utils.as_weighted_aggregator(
            full_gradient_aggregator),
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
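
A construction sketch reusing names from the surrounding examples (`model_examples.LinearRegression`, `sgdm.build_sgdm`); the hyperparameters are illustrative assumptions:

# Sketch: build the unweighted Mime Lite process and initialize it.
mime_process = build_unweighted_mime_lite(
    model_fn=model_examples.LinearRegression,
    base_optimizer=sgdm.build_sgdm(0.01))
state = mime_process.initialize()
# A round would then be: output = mime_process.next(state, client_datasets),
# where `client_datasets` is a list of per-client `tf.data.Dataset`s.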
Example #19
def _hadamard_mean():
  return rotation.HadamardTransformFactory(mean.UnweightedMeanFactory())
Example #20
def _dft_mean():
  return rotation.DiscreteFourierTransformFactory(mean.UnweightedMeanFactory())
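
For context, a hedged sketch of how these rotation-wrapped means might be exercised; the power-of-two dimension is our assumption, chosen because the fast Hadamard transform operates on such sizes:

# Sketch: values are rotated by the Hadamard transform before the inner
# unweighted mean is applied, and un-rotated on the way out.
process = _hadamard_mean().create(
    computation_types.TensorType(tf.float32, [8]))
state = process.initialize()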
Example #21
def build_unweighted_fed_prox(
    model_fn: Callable[[], model_lib.Model],
    proximal_strength: float,
    client_optimizer_fn: Union[optimizer_base.Optimizer,
                               Callable[[], tf.keras.optimizers.Optimizer]],
    server_optimizer_fn: Union[optimizer_base.Optimizer, Callable[
        [], tf.keras.optimizers.Optimizer]] = DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.UnweightedAggregationFactory] = None,
    metrics_aggregator: Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation] = metric_aggregator.sum_then_finalize,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs the FedProx algorithm.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  example-weighted FedProx on client models. This algorithm behaves the same as
  federated averaging, except that it uses a proximal regularization term that
  encourages clients to not drift too far from the server model.

  The iterative process has the following methods inherited from
  `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `client_optimizer_fn`. Each client computes the
  difference between the client model after training and the initial model.
  These model deltas are then aggregated at the server using an unweighted
  aggregation function. The aggregate model delta is applied at the server using
  a server optimizer, as in the FedOpt framework proposed in
  [Reddi et al., 2021](https://arxiv.org/abs/2003.00295).

  Note: The default server optimizer function is `tf.keras.optimizers.SGD`
  with a learning rate of 1.0, which corresponds to adding the model delta to
  the current server model. This recovers the original FedProx algorithm in
  [Li et al., 2020](https://arxiv.org/abs/1812.06127). More
  sophisticated federated averaging procedures may use different learning rates
  or server optimizers.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    proximal_strength: A nonnegative float representing the parameter of
      FedProx's regularization term. When set to `0`, the algorithm reduces to
      FedAvg. Higher values prevent clients from moving too far from the server
      model during local training.
    client_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`.
    server_optimizer_fn: A `tff.learning.optimizers.Optimizer`, or a no-arg
      callable that returns a `tf.keras.Optimizer`. By default, this uses
      `tf.keras.optimizers.SGD` with a learning rate of 1.0.
    model_distributor: An optional `DistributionProcess` that broadcasts the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.UnweightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.UnweightedMeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.

  Raises:
    ValueError: If `proximal_strength` is not a nonnegative float.
  """
    if model_aggregator is None:
        model_aggregator = mean.UnweightedMeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.UnweightedAggregationFactory)

    return build_weighted_fed_prox(
        model_fn=model_fn,
        proximal_strength=proximal_strength,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
        model_distributor=model_distributor,
        model_aggregator=factory_utils.as_weighted_aggregator(
            model_aggregator),
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
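
Finally, a hedged construction sketch; `my_model_fn` is a hypothetical placeholder, and `sgdm.build_sgdm` is reused from the earlier examples:

# Sketch: FedProx with uniform client weighting. proximal_strength > 0
# penalizes drift from the server model; 0 recovers FedAvg.
fed_prox_process = build_unweighted_fed_prox(
    model_fn=my_model_fn,  # hypothetical no-arg `tff.learning.Model` constructor
    proximal_strength=0.5,
    client_optimizer_fn=sgdm.build_sgdm(0.1))
state = fed_prox_process.initialize()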