Example #1
def _instantiate_aggregation_process(
        aggregation_factory,
        model_weights_type) -> aggregation_process_lib.AggregationProcess:
    """Constructs aggregation process given factory, checking compatibilty."""
    if aggregation_factory is None:
        aggregation_factory = mean.MeanFactory()
    py_typecheck.check_type(aggregation_factory,
                            factory.AggregationFactory.__args__)

    # We give precedence to unweighted aggregation.
    if isinstance(aggregation_factory, factory.UnweightedAggregationFactory):
        aggregation_process = aggregation_factory.create(
            model_weights_type.trainable)
    elif isinstance(aggregation_factory, factory.WeightedAggregationFactory):
        aggregation_process = aggregation_factory.create(
            model_weights_type.trainable,
            computation_types.TensorType(tf.float32))
    else:
        raise ValueError('Unknown type of aggregation factory: {}'.format(
            type(aggregation_factory)))

    process_signature = aggregation_process.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError('`aggregation_factory` does not produce a '
                        'compatible `AggregationProcess`. The processes must '
                        'retain the type structure of the inputs on the '
                        f'server, but got {input_client_value_type.member} != '
                        f'{result_server_value_type.member}.')

    return aggregation_process
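A minimal usage sketch for the helper above, assuming the same imports as the surrounding examples; `my_model_fn` is a hypothetical no-arg model constructor and `model_utils.weights_type_from_model` is the same utility used in the later examples on this page.

# Hypothetical sketch: derive the model weights type, then build the process.
model_weights_type = model_utils.weights_type_from_model(my_model_fn)
aggregation_process = _instantiate_aggregation_process(
    mean.MeanFactory(), model_weights_type)
state = aggregation_process.initialize()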
Example #2
  def test_type_properties(self, value_type, weight_type):
    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)

    factory_ = mean.MeanFactory()
    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    process = factory_.create(value_type, weight_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    param_value_type = computation_types.at_clients(value_type)
    result_value_type = computation_types.at_server(value_type)

    expected_state_type = computation_types.at_server(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=(), mean_weight=()))

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    expected_parameter = collections.OrderedDict(
        state=expected_state_type,
        value=param_value_type,
        weight=computation_types.at_clients(weight_type))

    expected_next_type = computation_types.FunctionType(
        parameter=expected_parameter,
        result=measured_process.MeasuredProcessOutput(
            expected_state_type, result_value_type, expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))
Example #3
  def test_inner_value_and_weight_sum_factory(self):
    sum_factory = aggregators_test_utils.SumPlusOneFactory()
    factory_ = mean.MeanFactory(
        value_sum_factory=sum_factory, weight_sum_factory=sum_factory)
    value_type = computation_types.to_type(tf.float32)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)

    state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
        state)

    # With SumPlusOneFactory, the weighted values sum to 1*3 + 2*2 + 3*1 + 1 =
    # 11.0 and the weights sum to 3 + 2 + 1 + 1 = 7.0.
    client_data = [1.0, 2.0, 3.0]
    weights = [3.0, 2.0, 1.0]

    output = process.next(state, client_data, weights)
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=1, weight_sum_process=1),
        output.state)
    self.assertAllClose(11 / 7, output.result)
    self.assertEqual(
        collections.OrderedDict(mean_value=M_CONST, mean_weight=M_CONST),
        output.measurements)
Example #4
def secure_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates secure aggregator with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.
  After weighting in the mean, the weighted values are summed using a
  cryptographic protocol that ensures the server cannot see individual updates
  until a sufficient number of updates have been added together. For details,
  see Bonawitz et al. (2017),
  https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized
  using the `tff.federated_secure_sum` operator.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    secure_clip_bound = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0)
    factory_ = mean.MeanFactory(secure.SecureSumFactory(secure_clip_bound))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
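A hedged sketch of turning the factory returned by `secure_aggregator` into an aggregation process; the scalar value and weight types are illustrative only, and the imports are assumed to match the surrounding example.

# Hypothetical sketch: create a process for a float32 value with float32 weights.
aggregator = secure_aggregator(zeroing=True, clipping=True)
process = aggregator.create(
    computation_types.to_type(tf.float32),
    computation_types.TensorType(tf.float32))
state = process.initialize()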
Example #5
    def test_add_measurements_to_weighted_aggregation_factory_output(self):
        mean_factory = mean.MeanFactory()
        debug_mean_factory = debug_measurements.add_debug_measurements(
            mean_factory)
        value_type = computation_types.TensorType(tf.float32)
        mean_aggregator = mean_factory.create(value_type, value_type)
        debug_aggregator = debug_mean_factory.create(value_type, value_type)

        state = mean_aggregator.initialize()
        mean_output = mean_aggregator.next(state, [2.0, 4.0], [1.0, 1.0])
        debug_output = debug_aggregator.next(state, [2.0, 4.0], [1.0, 1.0])
        self.assertEqual(mean_output.state, debug_output.state)
        self.assertNear(mean_output.result, debug_output.result, err=1e-6)

        mean_measurements = mean_output.measurements
        expected_debugging_measurements = {
            'average_client_norm': 3.0,
            'std_dev_client_norm': tf.math.sqrt(2.0),
            'server_update_max': 3.0,
            'server_update_norm': 3.0,
            'server_update_min': 3.0,
        }
        debugging_measurements = debug_output.measurements
        self.assertCountEqual(
            list(debugging_measurements.keys()),
            list(mean_measurements.keys()) +
            list(expected_debugging_measurements.keys()))
        for k in mean_output.measurements:
            self.assertEqual(mean_measurements[k], debugging_measurements[k])
        for k in expected_debugging_measurements:
            self.assertNear(debugging_measurements[k],
                            expected_debugging_measurements[k],
                            err=1e-6)
Example #6
def robust_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.

  For details on clipping and zeroing see `tff.aggregators.clipping_factory`
  and `tff.aggregators.zeroing_factory`. For details on the quantile-based
  adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    factory_ = mean.MeanFactory()

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
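A hedged sketch of one aggregation round with the robust aggregator; the client values and weights are toy data in the style of the test examples on this page, and execution assumes a default TFF local execution context.

# Hypothetical sketch: one weighted-mean round over three clients.
factory_ = robust_aggregator(zeroing=True, clipping=True)
process = factory_.create(
    computation_types.to_type(tf.float32),
    computation_types.to_type(tf.float32))
state = process.initialize()
output = process.next(state, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
# output.result holds the weighted mean; output.measurements should also carry
# the measurements added by the zeroing and clipping wrappers.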
Example #7
def compression_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates aggregator with compression and adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.
  After weighting in the mean, the weighted values are uniformly quantized to
  reduce the size of the model update communicated from clients to the server.
  For details, see Suresh et al. (2017),
  http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
  configuration is chosen such that compression does not have an adverse effect
  on trained model quality in typical tasks.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    factory_ = mean.MeanFactory(
        encoded.EncodedSumFactory.quantize_above_threshold(quantization_bits=8,
                                                           threshold=20000))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
Example #8
    def test_construction(self, weighted):
        aggregation_factory = (mean.MeanFactory()
                               if weighted else sum_factory.SumFactory())
        iterative_process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.LinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            model_update_aggregation_factory=aggregation_factory)

        if weighted:
            aggregate_state = collections.OrderedDict(value_sum_process=(),
                                                      weight_sum_process=())
            aggregate_metrics = collections.OrderedDict(mean_value=(),
                                                        mean_weight=())
        else:
            aggregate_state = ()
            aggregate_metrics = ()

        server_state_type = computation_types.FederatedType(
            optimizer_utils.ServerState(model=model_utils.ModelWeights(
                trainable=[
                    computation_types.TensorType(tf.float32, [2, 1]),
                    computation_types.TensorType(tf.float32)
                ],
                non_trainable=[computation_types.TensorType(tf.float32)]),
                                        optimizer_state=[tf.int64],
                                        delta_aggregate_state=aggregate_state,
                                        model_broadcast_state=()),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(parameter=None,
                                           result=server_state_type),
            iterative_process.initialize.type_signature)

        dataset_type = computation_types.FederatedType(
            computation_types.SequenceType(
                collections.OrderedDict(
                    x=computation_types.TensorType(tf.float32, [None, 2]),
                    y=computation_types.TensorType(tf.float32, [None, 1]))),
            placements.CLIENTS)
        metrics_type = computation_types.FederatedType(
            collections.OrderedDict(
                broadcast=(),
                aggregation=aggregate_metrics,
                train=collections.OrderedDict(
                    loss=computation_types.TensorType(tf.float32),
                    num_examples=computation_types.TensorType(tf.int32)),
                stat=collections.OrderedDict(
                    num_examples=computation_types.TensorType(tf.float32))),
            placements.SERVER)
        self.assert_types_equivalent(
            computation_types.FunctionType(parameter=collections.OrderedDict(
                server_state=server_state_type,
                federated_dataset=dataset_type,
            ),
                                           result=(server_state_type,
                                                   metrics_type)),
            iterative_process.next.type_signature)
Example #9
    def test_custom_model_zeroing_clipping_aggregator_factory(self):
        client_data = create_emnist_client_data()
        train_data = [client_data(), client_data()]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                counters.NumExamplesCounter(),
                counters.NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy()
            ]

        # No values should be zeroed since the zeroing norm is infinite.
        aggregation_factory = robust.zeroing_factory(
            zeroing_norm=float('inf'), inner_agg_factory=mean.MeanFactory())

        # Disable reconstruction via 0 learning rate to ensure post-recon loss
        # matches exact expectations in round 0 and decreases by the next round.
        trainer = training_process.build_training_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=_get_keras_optimizer_fn(0.01),
            client_optimizer_fn=_get_keras_optimizer_fn(0.001),
            reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
            aggregation_factory=aggregation_factory,
            dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
        state = trainer.initialize()

        outputs = []
        states = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            outputs.append(output)
            states.append(state)

        # All weights and biases are initialized to 0, so initial logits are all 0
        # and softmax probabilities are uniform over 10 classes. So negative log
        # likelihood is -ln(1/10). This is on expectation, so increase tolerance.
        self.assertAllClose(outputs[0]['train']['loss'],
                            tf.math.log(10.0),
                            rtol=1e-4)
        self.assertLess(outputs[1]['train']['loss'],
                        outputs[0]['train']['loss'])
        self.assertNotAllClose(states[0].model.trainable,
                               states[1].model.trainable)

        # Expect 6 reconstruction examples, 6 training examples. Only training
        # included in metrics.
        self.assertEqual(outputs[0]['train']['num_examples'], 6.0)
        self.assertEqual(outputs[1]['train']['num_examples'], 6.0)

        # Expect 4 reconstruction batches and 4 training batches. Only training
        # included in metrics.
        self.assertEqual(outputs[0]['train']['num_batches'], 4.0)
        self.assertEqual(outputs[1]['train']['num_batches'], 4.0)
Example #10
def secure_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
) -> factory.AggregationFactory:
    """Creates secure aggregator with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.
  After weighting in the mean, the weighted values are summed using a
  cryptographic protocol that ensures the server cannot see individual updates
  until a sufficient number of updates have been added together. For details,
  see Bonawitz et al. (2017),
  https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized
  using the `tff.federated_secure_sum_bitwidth` operator.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for secure aggregation.
    weighted: Whether the mean is weighted (vs. unweighted).

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    secure_clip_bound = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0,
        secure_estimation=True)

    factory_ = secure.SecureSumFactory(secure_clip_bound)

    if weighted:
        factory_ = mean.MeanFactory(
            value_sum_factory=factory_,
            # Use a power of 2 minus one to more accurately encode floating dtypes
            # that actually contain integer values. 2 ^ 20 gives us approximately a
            # range of [0, 1 million]. Existing use cases have the weights either
            # all ones, or a variant of number of examples processed locally.
            weight_sum_factory=secure.SecureSumFactory(
                upper_bound_threshold=float(2**20 - 1),
                lower_bound_threshold=0.0))
    else:
        factory_ = mean.UnweightedMeanFactory(
            value_sum_factory=factory_,
            count_sum_factory=secure.SecureSumFactory(upper_bound_threshold=1,
                                                      lower_bound_threshold=0))

    if clipping:
        factory_ = _default_clipping(factory_, secure_estimation=True)

    if zeroing:
        factory_ = _default_zeroing(factory_, secure_estimation=True)

    return factory_
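A hedged sketch of the unweighted path; with `weighted=False` the returned factory is unweighted, so `create` takes only the value type (which is illustrative here).

# Hypothetical sketch: unweighted secure mean over a float32 value.
aggregator = secure_aggregator(zeroing=True, clipping=True, weighted=False)
process = aggregator.create(computation_types.TensorType(tf.float32))
state = process.initialize()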
Example #11
def compression_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
    debug_measurements_fn: Optional[Callable[
        [factory.AggregationFactory], factory.AggregationFactory]] = None,
    **kwargs,
) -> factory.AggregationFactory:
    """Creates aggregator with compression and adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients and clips in the L2 norm to a moderately high norm for robustness to
  outliers. After weighting in the mean, the weighted values are uniformly
  quantized to reduce the size of the model update communicated from clients to
  the server. For details, see Suresh et al. (2017),
  http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
  configuration is chosen such that compression does not have an adverse effect
  on trained model quality in typical tasks.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
      Note this clipping is performed prior to the per-coordinate clipping
      required for quantization.
    weighted: Whether the mean is weighted (vs. unweighted).
    debug_measurements_fn: A callable to add measurements suitable for debugging
      learning algorithms. Possible values are `None`,
      `tff.learning.add_debug_measurements` or
      `tff.learning.add_debug_measurements_with_mixed_dtype`.
    **kwargs: Keyword arguments passed through to
      `encoded.EncodedSumFactory.quantize_above_threshold`.

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = encoded.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000, **kwargs)

    factory_ = (mean.MeanFactory(factory_)
                if weighted else mean.UnweightedMeanFactory(factory_))

    if debug_measurements_fn:
        factory_ = debug_measurements_fn(factory_)
        if (weighted and not isinstance(
                factory_, factory.WeightedAggregationFactory)) or (
                    (not weighted) and (not isinstance(
                        factory_, factory.UnweightedAggregationFactory))):
            raise TypeError(
                'debug_measurements_fn should return the same type.')

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
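A hedged sketch combining the compression aggregator with the debug-measurements wrapper used elsewhere on this page; the exact import path of `add_debug_measurements` may differ across TFF versions.

# Hypothetical sketch: attach debugging measurements to the compressed mean.
aggregator = compression_aggregator(
    zeroing=True,
    clipping=True,
    weighted=True,
    debug_measurements_fn=debug_measurements.add_debug_measurements)
process = aggregator.create(
    computation_types.TensorType(tf.float32),
    computation_types.TensorType(tf.float32))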
Example #12
    def test_raises_bad_measurement_fn(self):
        unweighted_factory = sum_factory.SumFactory()
        with self.assertRaisesRegex(ValueError, 'single parameter'):
            measurements.add_measurements(unweighted_factory,
                                          _get_weighted_min)

        weighted_factory = mean.MeanFactory()
        with self.assertRaisesRegex(ValueError, 'two parameters'):
            measurements.add_measurements(weighted_factory, _get_min)
Example #13
  def test_incorrect_create_type_raises(self, wrong_type):
    factory_ = mean.MeanFactory()
    correct_type = computation_types.to_type(tf.float32)
    with self.assertRaises(TypeError):
      factory_.create(wrong_type, correct_type)
    with self.assertRaises(TypeError):
      factory_.create(correct_type, wrong_type)

    factory_ = mean.UnweightedMeanFactory()
    with self.assertRaises(TypeError):
      factory_.create(wrong_type)
Example #14
  def test_weight_arg_all_zeros_no_nan_division(self):
    factory_ = mean.MeanFactory(no_nan_division=True)
    value_type = computation_types.to_type(tf.float32)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)

    state = process.initialize()
    client_data = [1.0, 2.0, 3.0]
    weights = [0.0, 0.0, 0.0]

    # Division by zero resulting in NaN/Inf *should not* occur.
    self.assertEqual(0.0, process.next(state, client_data, weights).result)
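For contrast, a hedged sketch of the default behaviour: without `no_nan_division=True`, an all-zero weight sum is expected to divide by zero and yield a non-finite result, which is exactly what the option above guards against.

# Hypothetical sketch: default MeanFactory with all-zero weights.
default_process = mean.MeanFactory().create(
    computation_types.to_type(tf.float32),
    computation_types.to_type(tf.float32))
result = default_process.next(
    default_process.initialize(), [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]).result
# result is expected to be NaN here, rather than the 0.0 asserted above.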
Example #15
  def test_aggregation_process_deprecation(self):
    aggregation_process = mean.MeanFactory().create(
        computation_types.to_type([(tf.float32, (2, 1)), tf.float32]),
        computation_types.TensorType(tf.float32))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter('always')
      federated_sgd.build_federated_sgd_process(
          model_fn=model_examples.LinearRegression,
          aggregation_process=aggregation_process)
      self.assertNotEmpty(w)
      self.assertEqual(w[0].category, DeprecationWarning)
      self.assertRegex(
          str(w[0].message), 'aggregation_process .* is deprecated')
Example #16
  def test_weight_arg(self):
    factory_ = mean.MeanFactory()
    value_type = computation_types.to_type(tf.float32)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)

    state = process.initialize()
    client_data = [1.0, 2.0, 3.0]
    weights = [1.0, 1.0, 1.0]
    self.assertEqual(2.0, process.next(state, client_data, weights).result)
    weights = [0.1, 0.1, 0.1]
    self.assertEqual(2.0, process.next(state, client_data, weights).result)
    weights = [6.0, 3.0, 1.0]
    # Weighted mean: (1*6 + 2*3 + 3*1) / (6 + 3 + 1) = 15 / 10 = 1.5.
    self.assertEqual(1.5, process.next(state, client_data, weights).result)
Example #17
    def test_weighted(self):
        factory = mean.MeanFactory()
        factory = measurements.add_measurements(factory, _get_weighted_min)
        process = factory.create(_float_type, _float_type)

        state = process.initialize()
        client_values = [1.0, 2.0, 3.0]
        client_weights = [3.0, 1.0, 2.0]
        output = process.next(state, client_values, client_weights)
        self.assertAllClose(11 / 6, output.result)
        self.assertDictEqual(
            collections.OrderedDict(mean_value=(),
                                    mean_weight=(),
                                    min_weighted_value=2.0),
            output.measurements)
Example #18
  def test_add_measurements_to_weighted_aggregation_factory_types(self):
    mean_factory = mean.MeanFactory()
    debug_mean_factory = debug_measurements.add_debug_measurements(
        mean_factory)
    value_type = computation_types.TensorType(tf.float32)
    mean_aggregator = mean_factory.create(value_type, value_type)
    debug_aggregator = debug_mean_factory.create(value_type, value_type)
    self.assertTrue(debug_aggregator.is_weighted)
    self.assertEqual(mean_aggregator.initialize.type_signature,
                     debug_aggregator.initialize.type_signature)
    self.assertEqual(mean_aggregator.next.type_signature.parameter,
                     debug_aggregator.next.type_signature.parameter)
    self.assertEqual(mean_aggregator.next.type_signature.result.state,
                     debug_aggregator.next.type_signature.result.state)
    self.assertEqual(mean_aggregator.next.type_signature.result.result,
                     debug_aggregator.next.type_signature.result.result)
Example #19
    def test_weighted_client(self):
        factory = mean.MeanFactory()

        factory = measurements.add_measurements(
            factory, client_measurement_fn=_get_min_weighted_norm)
        process = factory.create(_struct_type, _float_type)

        state = process.initialize()
        client_data = [_make_struct(x) for x in [1.0, 2.0, 3.0]]
        client_weights = [3.0, 1.0, 2.0]
        output = process.next(state, client_data, client_weights)
        self.assertAllClose(_make_struct(11 / 6), output.result)
        self.assertDictEqual(
            collections.OrderedDict(mean_value=(),
                                    mean_weight=(),
                                    min_weighted_norm=4.0),
            output.measurements)
Example #20
  def test_structure_value(self):
    factory_ = mean.MeanFactory()
    value_type = computation_types.to_type(_test_struct_type)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)
    expected_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())
    expected_measurements = collections.OrderedDict(
        mean_value=(), mean_weight=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    weights = [3.0, 2.0, 1.0]
    # Weighted sums are ((10.0, 16.0), 22.0) and the weights sum to 6.0.
    output = process.next(state, client_data, weights)
    self.assertAllEqual(expected_state, output.state)
    self.assertAllClose(((10. / 6., 16. / 6.), 22. / 6.), output.result)
    self.assertEqual(expected_measurements, output.measurements)
Example #21
def robust_aggregator(
    *,
    zeroing: bool = True,
    clipping: bool = True,
    weighted: bool = True,
    debug_measurements_fn: Optional[Callable[
        [factory.AggregationFactory], factory.AggregationFactory]] = None,
) -> factory.AggregationFactory:
    """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips in the L2 norm to a moderately high norm for robustness
  to outliers.

  For details on clipping and zeroing see `tff.aggregators.clipping_factory`
  and `tff.aggregators.zeroing_factory`. For details on the quantile-based
  adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing for data corruption mitigation.
    clipping: Whether to enable adaptive clipping in the L2 norm for robustness.
    weighted: Whether the mean is weighted (vs. unweighted).
    debug_measurements_fn: A callable to add measurements suitable for debugging
      learning algorithms. Often useful values include None,
      `tff.learning.add_debug_measurements` or
      `tff.learning.add_debug_measurements_with_mixed_dtype`.

  Returns:
    A `tff.aggregators.AggregationFactory`.
  """
    factory_ = mean.MeanFactory() if weighted else mean.UnweightedMeanFactory()

    if debug_measurements_fn:
        factory_ = debug_measurements_fn(factory_)

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
Example #22
def test_aggregator():
    return mean.MeanFactory().create(FLOAT_TYPE, FLOAT_TYPE)
Example #23
def _concat_mean():
    return concat.concat_factory(mean.MeanFactory())
Example #24
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`. If set to `None`, the
      server model is broadcast to the clients using the default
      `tff.federated_broadcast`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: If `broadcast_process` does not conform to the signature
      of broadcast (SERVER->CLIENTS).
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)

    model_weights_type = model_utils.weights_type_from_model(model_fn)

    if broadcast_process is None:
        broadcast_process = build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not is_valid_broadcast_process(broadcast_process):
        raise ProcessTypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))

    if model_update_aggregation_factory is None:
        model_update_aggregation_factory = mean.MeanFactory()
    py_typecheck.check_type(model_update_aggregation_factory,
                            factory.AggregationFactory.__args__)
    if isinstance(model_update_aggregation_factory,
                  factory.WeightedAggregationFactory):
        aggregation_process = model_update_aggregation_factory.create(
            model_weights_type.trainable,
            computation_types.TensorType(tf.float32))
    else:
        aggregation_process = model_update_aggregation_factory.create(
            model_weights_type.trainable)
    process_signature = aggregation_process.next.type_signature
    input_client_value_type = process_signature.parameter[1]
    result_server_value_type = process_signature.result[1]
    if input_client_value_type.member != result_server_value_type.member:
        raise TypeError(
            '`model_update_aggregation_factory` does not produce a '
            'compatible `AggregationProcess`. The processes must '
            'retain the type structure of the inputs on the '
            f'server, but got {input_client_value_type.member} != '
            f'{result_server_value_type.member}.')

    initialize_computation = _build_initialize_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    run_one_round_computation = _build_one_round_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        model_to_client_delta_fn=model_to_client_delta_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    return iterative_process.IterativeProcess(
        initialize_fn=initialize_computation,
        next_fn=run_one_round_computation)
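A hedged usage sketch mirroring the test in Example #8; `DummyClientDeltaFn` and `model_examples.LinearRegression` are the stand-ins used there for a real client update rule and model.

# Hypothetical sketch: build and initialize the iterative process.
iterative_process = build_model_delta_optimizer_process(
    model_fn=model_examples.LinearRegression,
    model_to_client_delta_fn=DummyClientDeltaFn,
    server_optimizer_fn=tf.keras.optimizers.SGD,
    model_update_aggregation_factory=mean.MeanFactory())
state = iterative_process.initialize()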
Example #25
def build_basic_fedavg_process(model_fn: Callable[[], model_lib.Model],
                               client_learning_rate: float):
    """Builds vanilla Federated Averaging process.

  The created process is the basic form of the Federated Averaging algorithm as
  proposed by http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf in
  Algorithm 1, for training the model created by `model_fn`. The following is
  the algorithm in pseudo-code:

  ```
  # Inputs: m: Initial model weights; eta: Client learning rate
  for i in num_rounds:
    for c in available_clients_indices:
      delta_m_c, w_c = client_update(m, eta)
    aggregate_model_delta = sum_c(delta_m_c * w_c) / sum_c(w_c)
    m = m - aggregate_model_delta
  return m  # Final trained model.

  def client_update(m, eta):
    initial_m = m
    for batch in client_dataset:
      m = m - eta * grad(m, batch)
    delta_m = initial_m - m
    return delta_m, size(client_dataset)
  ```

  The other algorithm hyper parameters (batch size, number of local epochs) are
  controlled by the data provided to the built process.

  An example usage of the returned `LearningProcess` in simulation:

  ```
  fedavg = build_basic_fedavg_process(model_fn, 0.1)

  # Create a `LearningAlgorithmState` containing the initial model weights for
  # the model returned from `model_fn`.
  state = fedavg.initialize()
  for _ in range(num_rounds):
    client_data = ...  # Preprocessed client datasets
    output = fedavg.next(state, client_data)
    write_round_metrics(output.metrics)
    # The new state contains the updated model weights after this round.
    state = output.state
  ```

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_learning_rate: A float. Learning rate for the SGD at clients.

  Returns:
    A `LearningProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(client_learning_rate, float)

    @computations.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result

    distributor = distributors.build_broadcast_process(model_weights_type)
    client_work = client_works.build_model_delta_client_work(
        model_fn, sgdm.build_sgdm(client_learning_rate))
    aggregator = mean.MeanFactory().create(
        client_work.next.type_signature.result.result.member.update,
        client_work.next.type_signature.result.result.member.update_weight)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        sgdm.build_sgdm(1.0), model_weights_type)

    return compose_learning_process(initial_model_weights_fn, distributor,
                                    client_work, aggregator, finalizer)
Example #26
def build_weighted_mime_lite(
    model_fn: Callable[[], model_lib.Model],
    base_optimizer: optimizer_base.Optimizer,
    server_optimizer: optimizer_base.Optimizer = sgdm.build_sgdm(1.0),
    client_weighting: Optional[
        client_weight_lib.
        ClientWeighting] = client_weight_lib.ClientWeighting.NUM_EXAMPLES,
    model_distributor: Optional[distributors.DistributionProcess] = None,
    model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
    full_gradient_aggregator: Optional[
        factory.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[
        model_lib.MetricFinalizersType, computation_types.StructWithPythonType
    ], computation_base.Computation]] = None,
    use_experimental_simulation_loop: bool = False
) -> learning_process.LearningProcess:
    """Builds a learning process that performs Mime Lite.

  This function creates a `tff.learning.templates.LearningProcess` that performs
  Mime Lite algorithm on client models. The iterative process has the following
  methods inherited from `tff.learning.templates.LearningProcess`:

  *   `initialize`: A `tff.Computation` with the functional type signature
      `( -> S@SERVER)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` representing the initial
      state of the server.
  *   `next`: A `tff.Computation` with the functional type signature
      `(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>)` where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `{B*}@CLIENTS` represents the client datasets.
      The output `L` contains the updated server state, as well as aggregated
      metrics at the server, including client training metrics and any other
      metrics from distribution and aggregation processes.
  *   `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`,
      where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type
      matches the output of `initialize` and `next`, and `M` represents the type
      of the model weights used during training.
  *   `set_model_weights`: A `tff.Computation` with type signature
      `(<S, M> -> S)`, where `S` is a
      `tff.learning.templates.LearningAlgorithmState` whose type matches the
      output of `initialize` and `M` represents the type of the model weights
      used during training.

  Each time the `next` method is called, the server model is communicated to
  each client using the provided `model_distributor`. For each client, local
  training is performed using `base_optimizer`, whose state is communicated by
  the server and kept intact during local training. The state is updated only
  at the server based on the full gradient evaluated by the clients based on the
  current server model state. The client full gradients are aggregated by
  weighted `full_gradient_aggregator`. Each client computes the difference
  between the client model after training and its initial model. These model
  deltas are then aggregated by weighted `model_aggregator`. Both of the
  aggregations are weighted, according to `client_weighting`. The aggregate
  model delta is added to the existing server model state.

  The Mime Lite algorithm is based on the paper
  "Breaking the centralized barrier for cross-device federated learning."
    Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank
    Reddi, Sebastian U. Stich, and Ananda Theertha Suresh.
    Advances in Neural Information Processing Systems 34 (2021).
    https://proceedings.neurips.cc/paper/2021/file/f0e6be4ce76ccfa73c5a540d992d0756-Paper.pdf

  Note that Keras optimizers are not supported. This is due to the Mime Lite
  algorithm applying the optimizer without changing its state at clients
  (optimizer's `tf.Variable`s in the case of Keras), which is not possible with
  Keras optimizers without reaching into private implementation details and
  incurring additional computation and memory cost at clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This method
      must *not* capture TensorFlow tensors or variables and use them. The model
      must be constructed entirely from scratch on each invocation; returning
      the same pre-constructed model on each call will result in an error.
    base_optimizer: A `tff.learning.optimizers.Optimizer` which will be used for
      both creating and updating a global optimizer state, as well as
      optimization at clients given the global state, which is fixed during the
      optimization.
    server_optimizer: A `tff.learning.optimizers.Optimizer` which will be used
      for applying the aggregate model update to the global model weights.
    client_weighting: A member of `tff.learning.ClientWeighting` that specifies
      a built-in weighting method. By default, weighting by number of examples
      is used.
    model_distributor: An optional `DistributionProcess` that distributes the
      model weights on the server to the clients. If set to `None`, the
      distributor is constructed via `distributors.build_broadcast_process`.
    model_aggregator: An optional `tff.aggregators.WeightedAggregationFactory`
      used to aggregate client updates on the server. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    full_gradient_aggregator: An optional
      `tff.aggregators.WeightedAggregationFactory` used to aggregate the full
      gradients on client datasets. If `None`, this is set to
      `tff.aggregators.MeanFactory`.
    metrics_aggregator: A function that takes in the metric finalizers (i.e.,
      `tff.learning.Model.metric_finalizers()`) and a
      `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF
      type of `tff.learning.Model.report_local_unfinalized_metrics()`), and
      returns a `tff.Computation` for aggregating the unfinalized metrics. If
      `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. It is
      currently necessary to set this flag to True for performant GPU
      simulations.

  Returns:
    A `tff.learning.templates.LearningProcess`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_type(base_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(server_optimizer, optimizer_base.Optimizer)
    py_typecheck.check_type(client_weighting,
                            client_weight_lib.ClientWeighting)

    @tensorflow_computation.tf_computation()
    def initial_model_weights_fn():
        return model_utils.ModelWeights.from_model(model_fn())

    model_weights_type = initial_model_weights_fn.type_signature.result
    if model_distributor is None:
        model_distributor = distributors.build_broadcast_process(
            model_weights_type)
    if model_aggregator is None:
        model_aggregator = mean.MeanFactory()
    py_typecheck.check_type(model_aggregator,
                            factory.WeightedAggregationFactory)
    model_aggregator = model_aggregator.create(
        model_weights_type.trainable, computation_types.TensorType(tf.float32))
    if full_gradient_aggregator is None:
        full_gradient_aggregator = mean.MeanFactory()
    py_typecheck.check_type(full_gradient_aggregator,
                            factory.WeightedAggregationFactory)

    client_work = _build_mime_lite_client_work(
        model_fn=model_fn,
        optimizer=base_optimizer,
        client_weighting=client_weighting,
        full_gradient_aggregator=full_gradient_aggregator,
        metrics_aggregator=metrics_aggregator,
        use_experimental_simulation_loop=use_experimental_simulation_loop)
    finalizer = finalizers.build_apply_optimizer_finalizer(
        server_optimizer, model_weights_type)
    return composers.compose_learning_process(initial_model_weights_fn,
                                              model_distributor, client_work,
                                              model_aggregator, finalizer)
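A hedged sketch of constructing the Mime Lite process; `my_model_fn` is a hypothetical no-arg `tff.learning.Model` constructor, `client_datasets` is a placeholder, and the optimizer settings are illustrative only.

# Hypothetical sketch: Mime Lite with SGD-with-momentum as the base optimizer.
mime_lite = build_weighted_mime_lite(
    model_fn=my_model_fn,
    base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9))
state = mime_lite.initialize()
# output = mime_lite.next(state, client_datasets)  # One training round.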
Example #27
def _zeroed_mean(clip=2.0, norm_order=2.0):
    return robust.zeroing_factory(clip, mean.MeanFactory(), norm_order)
Example #28
def _clipped_mean(clip=2.0):
    return robust.clipping_factory(clip, mean.MeanFactory())
Example #29
  def test_as_weighted_raises(self):
    with self.assertRaises(TypeError):
      factory_utils.as_weighted_aggregator(mean.MeanFactory())
Example #30
def build_model_delta_optimizer_process(
    model_fn: _ModelConstructor,
    model_to_client_delta_fn: Callable[[Callable[[], model_lib.Model]],
                                       ClientDeltaFn],
    server_optimizer_fn: _OptimizerConstructor,
    *,
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
    model_update_aggregation_factory: Optional[
        factory.AggregationFactory] = None,
) -> iterative_process.IterativeProcess:
    """Constructs `tff.templates.IterativeProcess` for Federated Averaging or SGD.

  This provides the TFF orchestration logic connecting the common server logic
  which applies aggregated model deltas to the server model with a
  `ClientDeltaFn` that specifies how `weight_deltas` are computed on device.

  Note: We pass in functions rather than constructed objects so we can ensure
  any variables or ops created in constructors are placed in the correct graph.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    model_to_client_delta_fn: A function from a `model_fn` to a `ClientDeltaFn`.
    server_optimizer_fn: A no-arg function that returns a `tf.Optimizer`. The
      `apply_gradients` method of this optimizer is used to apply client updates
      to the server model.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENT)`.
    aggregation_process: A `tff.templates.MeasuredProcess` that aggregates the
      model updates on the clients back to the server. It must support the
      signature `({input_values}@CLIENTS-> output_values@SERVER)`. Must be
      `None` if `model_update_aggregation_factory` is not `None`.
    model_update_aggregation_factory: An optional
      `tff.aggregators.WeightedAggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation. Must
      be `None` if `aggregation_process` is not `None`.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ProcessTypeError: if `broadcast_process` or `aggregation_process` do not
      conform to the signature of broadcast (SERVER->CLIENTS) or aggregation
      (CLIENTS->SERVER).
    DisjointArgumentError: if both `aggregation_process` and
      `model_update_aggregation_factory` are not `None`.
  """
    py_typecheck.check_callable(model_fn)
    py_typecheck.check_callable(model_to_client_delta_fn)
    py_typecheck.check_callable(server_optimizer_fn)

    model_weights_type = model_utils.weights_type_from_model(model_fn)

    if broadcast_process is None:
        broadcast_process = build_stateless_broadcaster(
            model_weights_type=model_weights_type)
    if not _is_valid_broadcast_process(broadcast_process):
        raise ProcessTypeError(
            'broadcast_process type signature does not conform to expected '
            'signature (<state@S, input@S> -> <state@S, result@C, measurements@S>).'
            ' Got: {t}'.format(t=broadcast_process.next.type_signature))

    if (model_update_aggregation_factory is not None
            and aggregation_process is not None):
        raise DisjointArgumentError(
            'Must specify only one of `model_update_aggregation_factory` and '
            '`aggregation_process`.')

    if aggregation_process is None:
        if model_update_aggregation_factory is None:
            model_update_aggregation_factory = mean.MeanFactory()

        py_typecheck.check_type(model_update_aggregation_factory,
                                factory.AggregationFactory.__args__)

        if isinstance(model_update_aggregation_factory,
                      factory.WeightedAggregationFactory):
            aggregation_process = model_update_aggregation_factory.create(
                model_weights_type.trainable,
                computation_types.TensorType(tf.float32))
        else:
            aggregation_process = model_update_aggregation_factory.create(
                model_weights_type.trainable)
    else:
        next_num_args = len(aggregation_process.next.type_signature.parameter)
        if next_num_args not in [2, 3]:
            raise ValueError(
                f'`next` function of `aggregation_process` must take two (for '
                f'unweighted aggregation) or three (for weighted aggregation) '
                f'arguments. Found {next_num_args}.')

    if not _is_valid_aggregation_process(aggregation_process):
        raise ProcessTypeError(
            'aggregation_process type signature does not conform to expected '
            'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
            ' Got: {t}'.format(t=aggregation_process.next.type_signature))

    initialize_computation = _build_initialize_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    run_one_round_computation = _build_one_round_computation(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        model_to_client_delta_fn=model_to_client_delta_fn,
        broadcast_process=broadcast_process,
        aggregation_process=aggregation_process)

    return iterative_process.IterativeProcess(
        initialize_fn=initialize_computation,
        next_fn=run_one_round_computation)