def model_update_aggregator(
    zeroing: Optional[AdaptiveZeroingConfig] = AdaptiveZeroingConfig(),
    clipping_and_noise: Optional[Union[ClippingConfig,
                                       DifferentialPrivacyConfig]] = None
) -> AggregationFactory:
    """Builds aggregator for model updates in FL according to configs.

  The default aggregator (produced if no arguments are overridden) performs
  mean with adaptive zeroing for robustness. To turn off adaptive zeroing set
  `zeroing=None`. (Adaptive) clipping and/or differential privacy can
  optionally be enabled by setting `clipping_and_noise`.

  Args:
    zeroing: A ZeroingConfig. If None, no zeroing will be performed.
    clipping_and_noise: An optional ClippingConfig or DifferentialPrivacyConfig.
      If unspecified, no clipping or noising will be performed.

  Returns:
    A `factory.WeightedAggregationFactory` intended for model update aggregation
      in federated averaging with zeroing and clipping for robustness.
  """
    if not clipping_and_noise:
        factory_ = mean_factory.MeanFactory()
    elif isinstance(clipping_and_noise, ClippingConfig):
        factory_ = _apply_clipping(clipping_and_noise,
                                   mean_factory.MeanFactory())
    elif isinstance(clipping_and_noise, DifferentialPrivacyConfig):
        factory_ = _dp_factory(clipping_and_noise)
    else:
        raise TypeError(
            f'clipping_and_noise must be a supported type of clipping '
            f'or noise config. Found type {type(clipping_and_noise)}.')
    if zeroing:
        factory_ = _apply_zeroing(zeroing, factory_)
    return factory_
def model_update_aggregator(
    zeroing: Optional[ZeroingConfig] = ZeroingConfig(),
    clipping_and_noise: Optional[Union[ClippingConfig, DPConfig]] = None
) -> Union[factory.WeightedAggregationFactory,
           factory.UnweightedAggregationFactory]:
  """Builds model update aggregator.

  Args:
    zeroing: A ZeroingConfig. If None, no zeroing will be performed.
    clipping_and_noise: An optional ClippingConfig or DPConfig. If unspecified,
      no clipping or noising will be performed.

  Returns:
    A `factory.WeightedAggregationFactory` or
    `factory.UnweightedAggregationFactory` intended for model update
    aggregation in federated averaging with zeroing and clipping for
    robustness.
  """
  if not clipping_and_noise:
    factory_ = mean_factory.MeanFactory()
  elif isinstance(clipping_and_noise, ClippingConfig):
    factory_ = clipping_and_noise.to_factory(mean_factory.MeanFactory())
  else:
    py_typecheck.check_type(clipping_and_noise, DPConfig, 'clipping_and_noise')
    factory_ = clipping_and_noise.to_factory()
  if zeroing:
    factory_ = zeroing.to_factory(factory_)
  return factory_
def compression_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates aggregator with compression and adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, clips to moderately high norm for robustness to outliers. After
  weighting in mean, the weighted values are uniformly quantized to reduce the
  size of the model update communicated from clients to the server. For details,
  see Suresh et al. (2017)
  http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
  configuration is chosen such that compression does not have adverse effect on
  trained model quality in typical tasks.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    factory_ = mean_factory.MeanFactory(
        encoded_factory.EncodedSumFactory.quantize_above_threshold(
            quantization_bits=8, threshold=20000))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
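# The docstring above points to uniform quantization (Suresh et al., 2017) as
# the compression mechanism. Below is a minimal stand-alone sketch of 8-bit
# uniform quantization; it is illustrative only and is not the TFF
# EncodedSumFactory implementation (the function name and the handling of a
# degenerate range are assumptions).
import numpy as np


def uniform_quantize(values, bits=8):
  """Maps values onto 2**bits evenly spaced levels between their min and max."""
  values = np.asarray(values, dtype=np.float64)
  lo, hi = values.min(), values.max()
  num_levels = 2**bits - 1
  scale = (hi - lo) / num_levels if hi > lo else 1.0
  levels = np.round((values - lo) / scale)  # integer level index per value
  return levels * scale + lo  # dequantized approximation of the input


# With 8 bits, the worst-case rounding error is half of one quantization step.
x = np.random.normal(size=1000)
step = (x.max() - x.min()) / (2**8 - 1)
assert np.max(np.abs(x - uniform_quantize(x, bits=8))) <= step / 2 + 1e-12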
def robust_aggregator(
    zeroing: bool = True,
    clipping: bool = True) -> factory.WeightedAggregationFactory:
  """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to a moderately high norm for robustness to outliers.

  For details on clipping and zeroing see `tff.aggregators.ClippingFactory` and
  `tff.aggregators.ZeroingFactory`. For details on the quantile-based adaptive
  algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory` with zeroing and clipping.
  """
  factory_ = mean_factory.MeanFactory()

  if clipping:
    # Adapts relatively quickly to a moderately high norm.
    clipping_norm = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=1.0, target_quantile=0.8, learning_rate=0.2)
    factory_ = clipping_factory.ClippingFactory(clipping_norm, factory_)

  if zeroing:
    factory_ = _default_zeroing(factory_)

  return factory_
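# In robust_aggregator above, clipping wraps the mean and zeroing is applied
# last, so zeroing ends up as the outermost aggregator. A tiny stand-alone
# illustration of that nesting pattern follows (hypothetical classes, not the
# TFF aggregator API), with fixed thresholds instead of adaptive ones.


class _Mean:

  def aggregate(self, values):
    return sum(values) / len(values)


class _Clip:
  """Clips each value to [-norm, norm] before delegating to `inner`."""

  def __init__(self, norm, inner):
    self.norm, self.inner = norm, inner

  def aggregate(self, values):
    return self.inner.aggregate(
        [max(min(v, self.norm), -self.norm) for v in values])


class _Zero:
  """Replaces values larger than `threshold` with 0 before delegating."""

  def __init__(self, threshold, inner):
    self.threshold, self.inner = threshold, inner

  def aggregate(self, values):
    return self.inner.aggregate(
        [0.0 if abs(v) > self.threshold else v for v in values])


# Same nesting as above: zeroing(clipping(mean)).
_agg = _Zero(100.0, _Clip(1.0, _Mean()))
assert _agg.aggregate([0.5, 2.0, 1e6]) == 0.5  # 1e6 zeroed, 2.0 clipped to 1.0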
def robust_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates aggregator for mean with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, and clips to moderately high norm for robustness to outliers.

  For details on clipping and zeroing see `tff.aggregators.clipping_factory`
  and `tff.aggregators.zeroing_factory`. For details on the quantile-based
  adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    factory_ = mean_factory.MeanFactory()

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
  def test_type_properties_with_inner_factory_weighted(self, value_type,
                                                       weight_type):
    sum_factory = aggregators_test_utils.SumPlusOneFactory()
    factory_ = mean_factory.MeanFactory(
        value_sum_factory=sum_factory, weight_sum_factory=sum_factory)
    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)
    process = factory_.create_weighted(value_type, weight_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    param_value_type = computation_types.FederatedType(value_type,
                                                       placements.CLIENTS)
    result_value_type = computation_types.FederatedType(value_type,
                                                        placements.SERVER)
    expected_state_type = expected_measurements_type = computation_types.FederatedType(
        collections.OrderedDict(
            value_sum_process=tf.int32, weight_sum_process=tf.int32),
        placements.SERVER)

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type,
            value=param_value_type,
            weight=computation_types.at_clients(weight_type)),
        result=measured_process.MeasuredProcessOutput(
            expected_state_type, result_value_type, expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))
    def test_type_properties(self, value_type):
        mean_f = mean_factory.MeanFactory()
        self.assertIsInstance(mean_f, factory.AggregationProcessFactory)
        value_type = computation_types.to_type(value_type)
        weight_type = computation_types.FederatedType(tf.float32,
                                                      placements.CLIENTS)
        process = mean_f.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        param_value_type = computation_types.FederatedType(
            value_type, placements.CLIENTS)
        result_value_type = computation_types.FederatedType(
            value_type, placements.SERVER)
        expected_state_type = computation_types.FederatedType(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=()), placements.SERVER)
        expected_measurements_type = computation_types.FederatedType(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=()), placements.SERVER)

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(state=expected_state_type,
                                              value=param_value_type,
                                              weight=weight_type),
            result=measured_process.MeasuredProcessOutput(
                expected_state_type, result_value_type,
                expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
  def test_inner_value_and_weight_sum_factory(self):
    sum_factory = aggregators_test_utils.SumPlusOneFactory()
    factory_ = mean_factory.MeanFactory(
        value_sum_factory=sum_factory, weight_sum_factory=sum_factory)
    value_type = computation_types.to_type(tf.float32)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)

    state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
        state)

    # Weighted values will be summed to 11.0 and weights will be summed to 7.0.
    client_data = [1.0, 2.0, 3.0]
    weights = [3.0, 2.0, 1.0]

    output = process.next(state, client_data, weights)
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=1, weight_sum_process=1),
        output.state)
    self.assertAllClose(11 / 7, output.result)
    self.assertEqual(
        collections.OrderedDict(mean_value=M_CONST, mean_weight=M_CONST),
        output.measurements)
def secure_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates secure aggregator with adaptive zeroing and clipping.

  Zeroes out extremely large values for robustness to data corruption on
  clients, clips to moderately high norm for robustness to outliers. After
  weighting in mean, the weighted values are summed using cryptographic protocol
  ensuring that the server cannot see individual updates until sufficient number
  of updates have been added together. For details, see Bonawitz et al. (2017)
  https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized
  using the `tff.federated_secure_sum` operator.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory`.
  """
    secure_clip_bound = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0)
    factory_ = mean_factory.MeanFactory(
        secure_factory.SecureSumFactory(secure_clip_bound))

    if clipping:
        factory_ = _default_clipping(factory_)

    if zeroing:
        factory_ = _default_zeroing(factory_)

    return factory_
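# The secure sum referenced above (Bonawitz et al., 2017) relies on additive
# masking: each pair of clients shares a random mask that one adds and the
# other subtracts, so the masks cancel in the aggregate while each individual
# report looks random. This is a toy sketch of that cancellation property only,
# not the tff.federated_secure_sum protocol (no dropouts, key agreement, or
# finite-field arithmetic).
import random


def pairwise_masked(updates, seed=0):
  """Returns per-client masked values whose sum equals sum(updates)."""
  rng = random.Random(seed)
  masked = list(updates)
  for i in range(len(updates)):
    for j in range(i + 1, len(updates)):
      mask = rng.uniform(-1e3, 1e3)  # randomness shared by clients i and j
      masked[i] += mask  # client i adds the pairwise mask
      masked[j] -= mask  # client j subtracts the same mask
  return masked


updates = [1.0, 2.0, 3.0]
masked = pairwise_masked(updates)
# Each masked value is uninformative on its own, but the sum is preserved.
assert abs(sum(masked) - sum(updates)) < 1e-6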
  def test_type_properties(self, value_type, weight_type):
    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)

    factory_ = mean_factory.MeanFactory()
    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    process = factory_.create(value_type, weight_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    param_value_type = computation_types.at_clients(value_type)
    result_value_type = computation_types.at_server(value_type)

    expected_state_type = computation_types.at_server(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=(), mean_weight=()))

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    expected_parameter = collections.OrderedDict(
        state=expected_state_type,
        value=param_value_type,
        weight=computation_types.at_clients(weight_type))

    expected_next_type = computation_types.FunctionType(
        parameter=expected_parameter,
        result=measured_process.MeasuredProcessOutput(
            expected_state_type, result_value_type, expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))
  def test_incorrect_create_type_raises(self, wrong_type):
    mean_f = mean_factory.MeanFactory()
    correct_type = computation_types.to_type(tf.float32)
    with self.assertRaises(TypeError):
      mean_f.create_weighted(wrong_type, correct_type)
    with self.assertRaises(TypeError):
      mean_f.create_weighted(correct_type, wrong_type)
  def test_raises_zeroing_norm_fn_bad_arg(self):
    zeroing_norm_fn = computations.tf_computation(lambda x: x + 3, tf.int32)
    with self.assertRaisesRegex(TypeError, 'Argument of `zeroing_norm_fn`'):
      clipping_factory.ZeroingClippingFactory(2.0, zeroing_norm_fn,
                                              mean_factory.MeanFactory())
  def test_apply_adaptive_clipping(self):
    factory_ = model_update_aggregator._apply_clipping(
        model_update_aggregator.AdaptiveClippingConfig(),
        mean_factory.MeanFactory())
    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    process = factory_.create_weighted(_test_type, _test_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)
  def test_zeroing_config_to_factory(self):
    qe_config = model_update_aggregator.QuantileEstimationConfig(
        initial_estimate=1.0, target_quantile=0.5, learning_rate=1.0)
    zeroing_config = model_update_aggregator.ZeroingConfig(
        quantile=qe_config, multiplier=10.0, increment=0.5)
    factory_ = zeroing_config.to_factory(mean_factory.MeanFactory())
    self.assertIsInstance(factory_, factory.WeightedAggregationFactory)
    def test_weight_arg_all_zeros_no_nan_division(self):
        mean_f = mean_factory.MeanFactory(no_nan_division=True)
        value_type = computation_types.to_type(tf.float32)
        process = mean_f.create(value_type)

        state = process.initialize()
        client_data = [1.0, 2.0, 3.0]
        weights = [0.0, 0.0, 0.0]
        # Division by zero resulting in NaN/Inf *should not* occur.
        self.assertEqual(0.0, process.next(state, client_data, weights).result)
  def test_weight_arg_all_zeros_nan_division(self):
    factory_ = mean_factory.MeanFactory(no_nan_division=False)
    value_type = computation_types.to_type(tf.float32)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create_weighted(value_type, weight_type)

    state = process.initialize()
    client_data = [1.0, 2.0, 3.0]
    weights = [0.0, 0.0, 0.0]
    # Division by zero resulting in NaN/Inf *should* occur.
    self.assertFalse(
        math.isfinite(process.next(state, client_data, weights).result))
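# The two tests above pin down the `no_nan_division` behavior: when the weights
# sum to zero, the mean is defined as 0.0 if `no_nan_division=True`, and is a
# non-finite value otherwise. A minimal numeric sketch of that contract follows
# (illustrative only, not the MeanFactory implementation):
import numpy as np


def weighted_mean(values, weights, no_nan_division=False):
  value_sum = float(np.dot(values, weights))
  weight_sum = float(np.sum(weights))
  if no_nan_division and weight_sum == 0.0:
    return 0.0
  return np.float64(value_sum) / np.float64(weight_sum)  # NaN when 0.0 / 0.0


assert weighted_mean([1.0, 2.0, 3.0], [0.0] * 3, no_nan_division=True) == 0.0
assert not np.isfinite(weighted_mean([1.0, 2.0, 3.0], [0.0] * 3))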
    def test_weight_arg(self):
        mean_f = mean_factory.MeanFactory()
        value_type = computation_types.to_type(tf.float32)
        process = mean_f.create(value_type)

        state = process.initialize()
        client_data = [1.0, 2.0, 3.0]
        weights = [1.0, 1.0, 1.0]
        self.assertEqual(2.0, process.next(state, client_data, weights).result)
        weights = [0.1, 0.1, 0.1]
        self.assertEqual(2.0, process.next(state, client_data, weights).result)
        weights = [6.0, 3.0, 1.0]
        self.assertEqual(1.5, process.next(state, client_data, weights).result)
  def test_structure_value_unweighted(self):
    factory_ = mean_factory.MeanFactory()
    value_type = computation_types.to_type(_test_struct_type)
    process = factory_.create_unweighted(value_type)
    expected_state = expected_measurements = collections.OrderedDict(
        value_sum_process=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    output = process.next(state, client_data)

    self.assertAllEqual(expected_state, output.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), output.result)
    self.assertEqual(expected_measurements, output.measurements)
  def test_scalar_value_unweighted(self):
    factory_ = mean_factory.MeanFactory()
    value_type = computation_types.to_type(tf.float32)

    process = factory_.create_unweighted(value_type)
    expected_state = expected_measurements = collections.OrderedDict(
        value_sum_process=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [1.0, 2.0, 3.0]
    output = process.next(state, client_data)
    self.assertAllClose(2.0, output.result)

    self.assertAllEqual(expected_state, output.state)
    self.assertEqual(expected_measurements, output.measurements)
  def test_inner_value_sum_factory_unweighted(self):
    sum_factory = aggregators_test_utils.SumPlusOneFactory()
    factory_ = mean_factory.MeanFactory(value_sum_factory=sum_factory)
    value_type = computation_types.to_type(tf.float32)
    process = factory_.create_unweighted(value_type)

    state = process.initialize()
    self.assertAllEqual(collections.OrderedDict(value_sum_process=0), state)

    client_data = [1.0, 2.0, 3.0]
    # Values will be summed to 7.0.
    output = process.next(state, client_data)
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=1), output.state)
    self.assertAllClose(7 / 3, output.result)
    self.assertEqual(
        collections.OrderedDict(value_sum_process=M_CONST), output.measurements)
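# Several tests above use `SumPlusOneFactory` from aggregators_test_utils.
# Judging only from the assertions (client values 1+2+3 reported as 7.0, state
# advancing from 0 to 1 per round, and the constant M_CONST surfaced as the
# measurement), its inner sum adds one to the total. A hypothetical stand-in
# that models just the state counter and the sum, purely to make the expected
# values above easy to verify:


def sum_plus_one(values, state=0):
  """Returns (updated round counter, sum of the values plus one)."""
  return state + 1, sum(values) + 1.0


state, result = sum_plus_one([1.0, 2.0, 3.0])
assert (state, result) == (1, 7.0)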
  def test_structure_value(self):
    factory_ = mean_factory.MeanFactory()
    value_type = computation_types.to_type(_test_struct_type)
    weight_type = computation_types.to_type(tf.float32)
    process = factory_.create(value_type, weight_type)
    expected_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())
    expected_measurements = collections.OrderedDict(
        mean_value=(), mean_weight=())

    state = process.initialize()
    self.assertAllEqual(expected_state, state)

    client_data = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    weights = [3.0, 2.0, 1.0]
    output = process.next(state, client_data, weights)
    self.assertAllEqual(expected_state, output.state)
    self.assertAllClose(((10. / 6., 16. / 6.), 22. / 6.), output.result)
    self.assertEqual(expected_measurements, output.measurements)
  def test_structure_value(self):
    mean_f = mean_factory.MeanFactory()
    value_type = computation_types.to_type(((tf.float32, (2,)), tf.float64))
    process = mean_f.create(value_type)

    state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        state)

    client_data = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    weights = [1.0, 1.0, 1.0]
    output = process.next(state, client_data, weights)
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        output.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), output.result)
    self.assertEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        output.measurements)
  def test_scalar_value(self):
    mean_f = mean_factory.MeanFactory()
    value_type = computation_types.to_type(tf.float32)
    process = mean_f.create(value_type)

    state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        state)

    client_data = [1.0, 2.0, 3.0]
    weights = [1.0, 1.0, 1.0]
    output = process.next(state, client_data, weights)
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        output.state)
    self.assertAllClose(2.0, output.result)
    self.assertEqual(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()),
        output.measurements)
def adaptive_zeroing_mean(
        initial_quantile_estimate: float,
        target_quantile: float,
        multiplier: float,
        increment: float,
        learning_rate: float,
        norm_order: float,
        no_nan_mean: bool = False) -> factory.WeightedAggregationFactory:
    """Creates a factory for mean with adaptive zeroing.

    Estimates the value at quantile `Z` of the value norm distribution and
    zeroes out values whose norm is greater than `rZ + i` for multiplier `r`
    and increment `i`. The quantile `Z` is estimated using the geometric method
    described in Thakkar et al. 2019, "Differentially Private Learning with
    Adaptive Clipping" (https://arxiv.org/abs/1905.03871) without noise added
    (so not differentially private).

    Args:
      initial_quantile_estimate: The initial estimate of the target quantile
        `Z`.
      target_quantile: Which quantile to match, as a float in [0, 1]. For
        example, 0.5 for median, or 0.98 to zero out only the largest 2% of
        updates (if multiplier=1 and increment=0).
      multiplier: Factor `r` in zeroing norm formula `rZ + i`.
      increment: Increment `i` in zeroing norm formula `rZ + i`.
      learning_rate: Learning rate for quantile matching algorithm.
      norm_order: A float for the order of the norm. Must be 1, 2, or np.inf.
      no_nan_mean: A bool. If True, the computed mean is 0 if sum of weights is
        equal to 0.

    Returns:
      A factory that performs mean after adaptive zeroing.
    """

    zeroing_quantile = _make_quantile_estimation_process(
        initial_estimate=initial_quantile_estimate,
        target_quantile=target_quantile,
        learning_rate=learning_rate)
    zeroing_norm = zeroing_quantile.map(
        _affine_transform(multiplier, increment))
    mean = mean_factory.MeanFactory(no_nan_division=no_nan_mean)
    return clipping_factory.ZeroingFactory(zeroing_norm, mean, norm_order)
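# The quantile estimate used above follows the geometric update described in
# the docstring (Thakkar et al. 2019). A minimal noise-free sketch of that
# update rule follows (not the PrivateQuantileEstimationProcess
# implementation): the estimate is multiplied by exp(-lr * (b - q)), where b is
# the fraction of client norms at or below the current estimate and q is the
# target quantile, so the estimate grows while too few norms fall below it.
import math


def update_quantile_estimate(estimate, norms, target_quantile, learning_rate):
  fraction_below = sum(n <= estimate for n in norms) / len(norms)
  return estimate * math.exp(
      -learning_rate * (fraction_below - target_quantile))


# Hypothetical per-client update norms; starting from 1.0, the estimate climbs
# until 80% of the norms fall at or below it, then stops moving.
norms = [0.5, 0.8, 1.2, 5.0, 40.0]
z = 1.0
for _ in range(50):
  z = update_quantile_estimate(z, norms, target_quantile=0.8, learning_rate=0.2)
assert 5.0 <= z <= 5.3  # settles just above the 80th-percentile norm of 5.0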
    def test_inner_weight_sum_factory(self):
        sum_factory = aggregators_test_utils.SumPlusOneFactory()
        mean_f = mean_factory.MeanFactory(weight_sum_factory=sum_factory)
        value_type = computation_types.to_type(tf.float32)
        process = mean_f.create(value_type)

        state = process.initialize()
        self.assertAllEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=0), state)

        client_data = [1.0, 2.0, 3.0]
        weights = [1.0, 1.0, 1.0]
        # Weights will be summed to 4.0.
        output = process.next(state, client_data, weights)
        self.assertAllEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=1), output.state)
        self.assertAllClose(1.5, output.result)
        self.assertEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=M_CONST),
            output.measurements)
def _zeroed_mean(clip=2.0, norm_order=2.0):
  return clipping_factory.ZeroingFactory(clip, mean_factory.MeanFactory(),
                                         norm_order)
def _clipped_mean(clip=2.0):
  return clipping_factory.ClippingFactory(clip, mean_factory.MeanFactory())
def _zeroed_mean(clip=2.0, norm_order=2.0):
    return robust_factory.zeroing_factory(clip, mean_factory.MeanFactory(),
                                          norm_order)
def _clipped_mean(clip=2.0):
    return robust_factory.clipping_factory(clip, mean_factory.MeanFactory())
  def test_incorrect_value_type_raises(self, bad_value_type):
    mean_f = mean_factory.MeanFactory()
    with self.assertRaises(TypeError):
      mean_f.create(bad_value_type)