コード例 #1
0
def model_update_aggregator(
    zeroing: Optional[AdaptiveZeroingConfig] = AdaptiveZeroingConfig(),
    clipping_and_noise: Optional[Union[ClippingConfig,
                                       DifferentialPrivacyConfig]] = None
) -> AggregationFactory:
    """Builds aggregator for model updates in FL according to configs.

    The default aggregator (produced if no arguments are overridden) performs
    mean with adaptive zeroing for robustness. To turn off adaptive zeroing set
    `zeroing=None`. (Adaptive) clipping and/or differential privacy can
    optionally be enabled by setting `clipping_and_noise`.

    Args:
      zeroing: An `AdaptiveZeroingConfig`. If None, no zeroing will be
        performed. NOTE(review): the default instance is constructed once, when
        this function is defined, and is shared by every call that omits the
        argument; presumably the config is immutable — confirm.
      clipping_and_noise: An optional `ClippingConfig` or
        `DifferentialPrivacyConfig`. If None, no clipping or noising will be
        performed.

    Returns:
      An `AggregationFactory` intended for model update aggregation: a mean
      (optionally wrapped with clipping, or built by `_dp_factory` for DP),
      optionally wrapped with zeroing as the outermost layer.

    Raises:
      TypeError: If `clipping_and_noise` is not None and is neither a
        `ClippingConfig` nor a `DifferentialPrivacyConfig`.
    """
    if clipping_and_noise is None:
        factory_ = mean_factory.MeanFactory()
    elif isinstance(clipping_and_noise, ClippingConfig):
        factory_ = _apply_clipping(clipping_and_noise,
                                   mean_factory.MeanFactory())
    elif isinstance(clipping_and_noise, DifferentialPrivacyConfig):
        # The DP factory is built from the config alone; no inner mean is passed.
        factory_ = _dp_factory(clipping_and_noise)
    else:
        raise TypeError(
            'clipping_and_noise must be a supported type of clipping '
            f'or noise config. Found type {type(clipping_and_noise)}.')
    # Zeroing is applied last, so it wraps every other layer built above.
    if zeroing is not None:
        factory_ = _apply_zeroing(zeroing, factory_)
    return factory_
コード例 #2
0
def model_update_aggregator(
    zeroing: Optional[ZeroingConfig] = ZeroingConfig(),
    clipping_and_noise: Optional[Union[ClippingConfig, DPConfig]] = None
) -> Union[factory.WeightedAggregationFactory,
           factory.UnweightedAggregationFactory]:
  """Builds model update aggregator.

  Args:
    zeroing: A ZeroingConfig. If None, no zeroing will be performed.
      NOTE(review): the default instance is constructed once, when this
      function is defined, and is shared by every call that omits the
      argument; presumably the config is immutable — confirm.
    clipping_and_noise: An optional ClippingConfig or DPConfig. If None, no
      clipping or noising will be performed.

  Returns:
    A weighted or unweighted aggregation factory (see the return annotation)
    intended for model update aggregation in federated averaging, with optional
    zeroing and clipping/noise for robustness.

  Raises:
    TypeError: Presumably, via `py_typecheck.check_type`, when
      `clipping_and_noise` is not None and is neither a ClippingConfig nor a
      DPConfig — confirm against `py_typecheck`.
  """
  if clipping_and_noise is None:
    factory_ = mean_factory.MeanFactory()
  elif isinstance(clipping_and_noise, ClippingConfig):
    factory_ = clipping_and_noise.to_factory(mean_factory.MeanFactory())
  else:
    py_typecheck.check_type(clipping_and_noise, DPConfig, 'clipping_and_noise')
    # The DP config builds its own factory; no inner mean is passed.
    factory_ = clipping_and_noise.to_factory()
  # Zeroing is applied last, so it wraps every other layer built above.
  if zeroing is not None:
    factory_ = zeroing.to_factory(factory_)
  return factory_
コード例 #3
0
def compression_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates a quantizing mean aggregator with adaptive zeroing and clipping.

    Extremely large values are zeroed out for robustness to data corruption on
    clients, and values are clipped to a moderately high norm for robustness to
    outliers. After weighting in the mean, the weighted values are uniformly
    quantized to shrink the model update sent from clients to the server; see
    Suresh et al. (2017),
    http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default
    configuration is chosen such that compression does not have adverse effect
    on trained model quality in typical tasks.

    Args:
      zeroing: Whether to enable adaptive zeroing.
      clipping: Whether to enable adaptive clipping.

    Returns:
      A `tff.aggregators.WeightedAggregationFactory`.
    """
    # 8-bit quantization, applied according to a threshold of 20000.
    quantized_sum = encoded_factory.EncodedSumFactory.quantize_above_threshold(
        quantization_bits=8, threshold=20000)
    aggregator = mean_factory.MeanFactory(quantized_sum)

    # Clipping wraps the mean first; zeroing then wraps the result, making it
    # the outermost layer.
    aggregator = _default_clipping(aggregator) if clipping else aggregator
    aggregator = _default_zeroing(aggregator) if zeroing else aggregator
    return aggregator
コード例 #4
0
def robust_aggregator(
    zeroing: bool = True,
    clipping: bool = True) -> factory.WeightedAggregationFactory:
  """Creates a mean aggregator with optional adaptive zeroing and clipping.

  Extremely large values are zeroed out for robustness to data corruption on
  clients, and values are clipped to a moderately high norm for robustness to
  outliers.

  For details on clipping and zeroing see `tff.aggregators.ClippingFactory` and
  `tff.aggregators.ZeroingFactory`. For details on the quantile-based adaptive
  algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

  Args:
    zeroing: Whether to enable adaptive zeroing.
    clipping: Whether to enable adaptive clipping.

  Returns:
    A `tff.aggregators.WeightedAggregationFactory` with zeroing and clipping.
  """
  aggregator = mean_factory.MeanFactory()

  if clipping:
    # Noise-free estimate starting at 1.0 and tracking the 0.8 quantile with
    # learning rate 0.2: adapts relatively quickly to a moderately high norm.
    estimator = quantile_estimation.PrivateQuantileEstimationProcess
    adaptive_norm = estimator.no_noise(
        initial_estimate=1.0, target_quantile=0.8, learning_rate=0.2)
    aggregator = clipping_factory.ClippingFactory(adaptive_norm, aggregator)

  if zeroing:
    # Applied last, so zeroing is the outermost wrapper.
    aggregator = _default_zeroing(aggregator)

  return aggregator
コード例 #5
0
def robust_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates a mean aggregator with optional adaptive zeroing and clipping.

    Extremely large values are zeroed out for robustness to data corruption on
    clients, and values are clipped to a moderately high norm for robustness to
    outliers.

    For details on clipping and zeroing see `tff.aggregators.clipping_factory`
    and `tff.aggregators.zeroing_factory`. For details on the quantile-based
    adaptive algorithm see `tff.aggregators.PrivateQuantileEstimationProcess`.

    Args:
      zeroing: Whether to enable adaptive zeroing.
      clipping: Whether to enable adaptive clipping.

    Returns:
      A `tff.aggregators.WeightedAggregationFactory`.
    """
    aggregator = mean_factory.MeanFactory()
    # Clipping wraps the mean; zeroing (if enabled) wraps everything.
    aggregator = _default_clipping(aggregator) if clipping else aggregator
    aggregator = _default_zeroing(aggregator) if zeroing else aggregator
    return aggregator
コード例 #6
0
  def test_type_properties_with_inner_factory_weighted(self, value_type,
                                                       weight_type):
    """Checks weighted process type signatures when inner sums are supplied.

    The same `SumPlusOneFactory` is used for both the value and the weight sum,
    so the expected server state and measurements each hold one int32 entry
    per inner process.
    """
    inner_sum = aggregators_test_utils.SumPlusOneFactory()
    mean_f = mean_factory.MeanFactory(
        value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
    self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)

    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)
    process = mean_f.create_weighted(value_type, weight_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    clients_value_type = computation_types.FederatedType(
        value_type, placements.CLIENTS)
    server_value_type = computation_types.FederatedType(
        value_type, placements.SERVER)
    # State and measurements share one expected type object.
    expected_state_type = computation_types.FederatedType(
        collections.OrderedDict(
            value_sum_process=tf.int32, weight_sum_process=tf.int32),
        placements.SERVER)
    expected_measurements_type = expected_state_type

    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            computation_types.FunctionType(
                parameter=None, result=expected_state_type)))

    next_parameter = collections.OrderedDict(
        state=expected_state_type,
        value=clients_value_type,
        weight=computation_types.at_clients(weight_type))
    next_result = measured_process.MeasuredProcessOutput(
        expected_state_type, server_value_type, expected_measurements_type)
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(
            computation_types.FunctionType(
                parameter=next_parameter, result=next_result)))
コード例 #7
0
    def test_type_properties(self, value_type):
        """Checks type signatures of the process built by `create`."""
        mean_f = mean_factory.MeanFactory()
        self.assertIsInstance(mean_f, factory.AggregationProcessFactory)
        value_type = computation_types.to_type(value_type)
        clients_weight_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)
        process = mean_f.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        def empty_inner_processes_at_server():
            # Fresh server-placed type with `()` for both inner sum processes.
            return computation_types.FederatedType(
                collections.OrderedDict(value_sum_process=(),
                                        weight_sum_process=()),
                placements.SERVER)

        clients_value_type = computation_types.FederatedType(
            value_type, placements.CLIENTS)
        server_value_type = computation_types.FederatedType(
            value_type, placements.SERVER)
        expected_state_type = empty_inner_processes_at_server()
        expected_measurements_type = empty_inner_processes_at_server()

        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                computation_types.FunctionType(
                    parameter=None, result=expected_state_type)))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(state=expected_state_type,
                                              value=clients_value_type,
                                              weight=clients_weight_type),
            result=measured_process.MeasuredProcessOutput(
                expected_state_type, server_value_type,
                expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
コード例 #8
0
  def test_inner_value_and_weight_sum_factory(self):
    """Checks that inner sum factories are used for both values and weights."""
    plus_one_sum = aggregators_test_utils.SumPlusOneFactory()
    mean_f = mean_factory.MeanFactory(
        value_sum_factory=plus_one_sum, weight_sum_factory=plus_one_sum)
    process = mean_f.create(
        computation_types.to_type(tf.float32),
        computation_types.to_type(tf.float32))

    initial_state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
        initial_state)

    # Weighted values: 1*3 + 2*2 + 3*1 = 10, plus one from the inner sum -> 11.
    # Weights: 3 + 2 + 1 = 6, plus one from the inner sum -> 7.
    values = [1.0, 2.0, 3.0]
    value_weights = [3.0, 2.0, 1.0]
    output = process.next(initial_state, values, value_weights)

    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=1, weight_sum_process=1),
        output.state)
    self.assertAllClose(11 / 7, output.result)
    self.assertEqual(
        collections.OrderedDict(mean_value=M_CONST, mean_weight=M_CONST),
        output.measurements)
コード例 #9
0
def secure_aggregator(
        zeroing: bool = True,
        clipping: bool = True) -> factory.WeightedAggregationFactory:
    """Creates a secure-sum mean aggregator with adaptive zeroing and clipping.

    Extremely large values are zeroed out for robustness to data corruption on
    clients, and values are clipped to a moderately high norm for robustness to
    outliers. After weighting in the mean, the weighted values are summed with
    a cryptographic protocol ensuring that the server cannot see individual
    updates until a sufficient number of updates have been added together; see
    Bonawitz et al. (2017) https://dl.acm.org/doi/abs/10.1145/3133956.3133982.
    In TFF, this is realized using the `tff.federated_secure_sum` operator.

    Args:
      zeroing: Whether to enable adaptive zeroing.
      clipping: Whether to enable adaptive clipping.

    Returns:
      A `tff.aggregators.WeightedAggregationFactory`.
    """
    # Noise-free adaptive bound handed to the secure sum factory.
    quantile_process = quantile_estimation.PrivateQuantileEstimationProcess
    secure_sum_bound = quantile_process.no_noise(
        initial_estimate=50.0,
        target_quantile=0.95,
        learning_rate=1.0,
        multiplier=2.0)
    aggregator = mean_factory.MeanFactory(
        secure_factory.SecureSumFactory(secure_sum_bound))

    if clipping:
        aggregator = _default_clipping(aggregator)
    # Applied last, so zeroing is the outermost wrapper.
    if zeroing:
        aggregator = _default_zeroing(aggregator)
    return aggregator
コード例 #10
0
  def test_type_properties(self, value_type, weight_type):
    """Checks type signatures of the weighted process built by `create`."""
    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)

    mean_f = mean_factory.MeanFactory()
    self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)
    process = mean_f.create(value_type, weight_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # With no inner factories both inner states are `()`; measurements are
    # keyed by mean_value / mean_weight.
    state_type = computation_types.at_server(
        collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
    measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=(), mean_weight=()))

    init_signature = process.initialize.type_signature
    self.assertTrue(
        init_signature.is_equivalent_to(
            computation_types.FunctionType(parameter=None, result=state_type)))

    next_signature = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(value_type),
            weight=computation_types.at_clients(weight_type)),
        result=measured_process.MeasuredProcessOutput(
            state_type, computation_types.at_server(value_type),
            measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(next_signature))
コード例 #11
0
 def test_incorrect_create_type_raises(self, wrong_type):
     """Checks that a bad value or weight type makes create_weighted raise."""
     mean_f = mean_factory.MeanFactory()
     float_type = computation_types.to_type(tf.float32)
     # The wrong type is tried first as the value type, then as the weight type.
     for value_t, weight_t in ((wrong_type, float_type),
                               (float_type, wrong_type)):
         with self.assertRaises(TypeError):
             mean_f.create_weighted(value_t, weight_t)
コード例 #12
0
 def test_raises_zeroing_norm_fn_bad_arg(self):
     """Checks that a zeroing_norm_fn taking tf.int32 is rejected."""
     int_norm_fn = computations.tf_computation(lambda x: x + 3, tf.int32)
     with self.assertRaisesRegex(TypeError, 'Argument of `zeroing_norm_fn`'):
         clipping_factory.ZeroingClippingFactory(
             2.0, int_norm_fn, mean_factory.MeanFactory())
コード例 #13
0
 def test_apply_adaptive_clipping(self):
     """Checks _apply_clipping with a default AdaptiveClippingConfig."""
     clipping_config = model_update_aggregator.AdaptiveClippingConfig()
     clipped_mean = model_update_aggregator._apply_clipping(
         clipping_config, mean_factory.MeanFactory())
     self.assertIsInstance(clipped_mean, factory.WeightedAggregationFactory)
     self.assertIsInstance(
         clipped_mean.create_weighted(_test_type, _test_type),
         aggregation_process.AggregationProcess)
コード例 #14
0
 def test_zeroing_config_to_factory(self):
   """Checks that ZeroingConfig.to_factory yields a weighted factory."""
   quantile_config = model_update_aggregator.QuantileEstimationConfig(
       initial_estimate=1.0, target_quantile=0.5, learning_rate=1.0)
   zeroed_mean = model_update_aggregator.ZeroingConfig(
       quantile=quantile_config, multiplier=10.0,
       increment=0.5).to_factory(mean_factory.MeanFactory())
   self.assertIsInstance(zeroed_mean, factory.WeightedAggregationFactory)
コード例 #15
0
    def test_weight_arg_all_zeros_no_nan_division(self):
        """With no_nan_division=True, all-zero weights give a 0.0 mean."""
        process = mean_factory.MeanFactory(no_nan_division=True).create(
            computation_types.to_type(tf.float32))

        output = process.next(process.initialize(), [1.0, 2.0, 3.0],
                              [0.0, 0.0, 0.0])
        # Division by zero resulting in NaN/Inf *should not* occur.
        self.assertEqual(0.0, output.result)
コード例 #16
0
  def test_weight_arg_all_zeros_nan_division(self):
    """With no_nan_division=False, all-zero weights give a non-finite mean."""
    mean_f = mean_factory.MeanFactory(no_nan_division=False)
    float_value_type = computation_types.to_type(tf.float32)
    float_weight_type = computation_types.to_type(tf.float32)
    process = mean_f.create_weighted(float_value_type, float_weight_type)

    output = process.next(process.initialize(), [1.0, 2.0, 3.0],
                          [0.0, 0.0, 0.0])
    # Division by zero resulting in NaN/Inf *should* occur.
    self.assertFalse(math.isfinite(output.result))
コード例 #17
0
    def test_weight_arg(self):
        """Checks the weighted mean for several weightings of the same data."""
        process = mean_factory.MeanFactory().create(
            computation_types.to_type(tf.float32))
        state = process.initialize()
        client_data = [1.0, 2.0, 3.0]
        cases = (
            ([1.0, 1.0, 1.0], 2.0),
            ([0.1, 0.1, 0.1], 2.0),  # Uniform weights: the scale is irrelevant.
            ([6.0, 3.0, 1.0], 1.5),  # (6 + 6 + 3) / 10.
        )
        for weights, expected_mean in cases:
            result = process.next(state, client_data, weights).result
            self.assertEqual(expected_mean, result)
コード例 #18
0
  def test_structure_value_unweighted(self):
    """Checks unweighted mean over a nested structure of values."""
    process = mean_factory.MeanFactory().create_unweighted(
        computation_types.to_type(_test_struct_type))
    # Only a value sum process exists; state and measurements are both `()`.
    expected = collections.OrderedDict(value_sum_process=())

    state = process.initialize()
    self.assertAllEqual(expected, state)

    values = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    output = process.next(state, values)

    self.assertAllEqual(expected, output.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), output.result)
    self.assertEqual(expected, output.measurements)
コード例 #19
0
  def test_scalar_value_unweighted(self):
    """Checks unweighted mean of scalar float values."""
    mean_f = mean_factory.MeanFactory()
    process = mean_f.create_unweighted(computation_types.to_type(tf.float32))
    # Only a value sum process exists; state and measurements are both `()`.
    expected = collections.OrderedDict(value_sum_process=())

    state = process.initialize()
    self.assertAllEqual(expected, state)

    output = process.next(state, [1.0, 2.0, 3.0])
    self.assertAllClose(2.0, output.result)

    self.assertAllEqual(expected, output.state)
    self.assertEqual(expected, output.measurements)
コード例 #20
0
  def test_inner_value_sum_factory_unweighted(self):
    """Checks that an inner value sum factory is used by the unweighted mean."""
    plus_one_sum = aggregators_test_utils.SumPlusOneFactory()
    process = mean_factory.MeanFactory(
        value_sum_factory=plus_one_sum).create_unweighted(
            computation_types.to_type(tf.float32))

    initial_state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=0), initial_state)

    # 1 + 2 + 3 = 6, plus one from the inner sum -> 7, then divided by 3.
    output = process.next(initial_state, [1.0, 2.0, 3.0])
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=1), output.state)
    self.assertAllClose(7 / 3, output.result)
    self.assertEqual(
        collections.OrderedDict(value_sum_process=M_CONST), output.measurements)
コード例 #21
0
  def test_structure_value(self):
    """Checks weighted mean over a nested structure of values."""
    mean_f = mean_factory.MeanFactory()
    process = mean_f.create(
        computation_types.to_type(_test_struct_type),
        computation_types.to_type(tf.float32))
    empty_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())
    empty_measurements = collections.OrderedDict(
        mean_value=(), mean_weight=())

    state = process.initialize()
    self.assertAllEqual(empty_state, state)

    # Weights sum to 6; each leaf is sum(value * weight) / 6.
    values = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    output = process.next(state, values, [3.0, 2.0, 1.0])
    self.assertAllEqual(empty_state, output.state)
    self.assertAllClose(((10. / 6., 16. / 6.), 22. / 6.), output.result)
    self.assertEqual(empty_measurements, output.measurements)
コード例 #22
0
  def test_structure_value(self):
    """Checks equal-weight mean over a (float32[2], float64) structure."""
    process = mean_factory.MeanFactory().create(
        computation_types.to_type(((tf.float32, (2,)), tf.float64)))
    no_inner_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())

    state = process.initialize()
    self.assertAllEqual(no_inner_state, state)

    values = [((1.0, 2.0), 3.0), ((2.0, 5.0), 4.0), ((3.0, 0.0), 5.0)]
    output = process.next(state, values, [1.0, 1.0, 1.0])
    self.assertAllEqual(no_inner_state, output.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), output.result)
    self.assertEqual(no_inner_state, output.measurements)
コード例 #23
0
  def test_scalar_value(self):
    """Checks equal-weight mean of scalar float values."""
    process = mean_factory.MeanFactory().create(
        computation_types.to_type(tf.float32))
    no_inner_state = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())

    state = process.initialize()
    self.assertAllEqual(no_inner_state, state)

    output = process.next(state, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
    self.assertAllEqual(no_inner_state, output.state)
    self.assertAllClose(2.0, output.result)
    self.assertEqual(no_inner_state, output.measurements)
コード例 #24
0
def adaptive_zeroing_mean(
        initial_quantile_estimate: float,
        target_quantile: float,
        multiplier: float,
        increment: float,
        learning_rate: float,
        norm_order: float,
        no_nan_mean: bool = False) -> factory.WeightedAggregationFactory:
    """Creates a factory for mean with adaptive zeroing.

    Estimates value at quantile `Z` of value norm distribution and zeroes out
    values whose norm is greater than `rZ + i` for multiplier `r` and increment
    `i`. The quantile `Z` is estimated using the geometric method described in
    Thakkar et al. 2019, "Differentially Private Learning with Adaptive
    Clipping" (https://arxiv.org/abs/1905.03871) without noise added (so not
    differentially private).

    Args:
      initial_quantile_estimate: The initial estimate of the target quantile
        `Z`.
      target_quantile: Which quantile to match, as a float in [0, 1]. For
        example, 0.5 for median, or 0.98 to zero out only the largest 2% of
        updates (if multiplier=1 and increment=0).
      multiplier: Factor `r` in zeroing norm formula `rZ + i`.
      increment: Increment `i` in zeroing norm formula `rZ + i`.
      learning_rate: Learning rate for quantile matching algorithm.
      norm_order: A float for the order of the norm. Must be 1, 2, or np.inf.
      no_nan_mean: A bool. If True, the computed mean is 0 if sum of weights is
        equal to 0.

    Returns:
      A `ZeroingFactory` wrapping a `MeanFactory`, i.e. a factory that performs
      mean after adaptive zeroing.
    """
    zeroing_quantile = _make_quantile_estimation_process(
        initial_estimate=initial_quantile_estimate,
        target_quantile=target_quantile,
        learning_rate=learning_rate)
    # Turn the quantile estimate Z into the zeroing threshold rZ + i.
    zeroing_norm = zeroing_quantile.map(
        _affine_transform(multiplier, increment))
    mean = mean_factory.MeanFactory(no_nan_division=no_nan_mean)
    return clipping_factory.ZeroingFactory(zeroing_norm, mean, norm_order)
コード例 #25
0
    def test_inner_weight_sum_factory(self):
        """Checks that an inner weight sum factory is used by the mean."""
        plus_one_sum = aggregators_test_utils.SumPlusOneFactory()
        process = mean_factory.MeanFactory(
            weight_sum_factory=plus_one_sum).create(
                computation_types.to_type(tf.float32))

        initial_state = process.initialize()
        self.assertAllEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=0), initial_state)

        # Values sum to 6; weights sum to 3, plus one from the inner sum -> 4.
        output = process.next(initial_state, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
        self.assertAllEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=1), output.state)
        self.assertAllClose(1.5, output.result)
        self.assertEqual(
            collections.OrderedDict(value_sum_process=(),
                                    weight_sum_process=M_CONST),
            output.measurements)
コード例 #26
0
def _zeroed_mean(clip=2.0, norm_order=2.0):
  """Returns a default MeanFactory wrapped in a ZeroingFactory."""
  inner_mean = mean_factory.MeanFactory()
  return clipping_factory.ZeroingFactory(clip, inner_mean, norm_order)
コード例 #27
0
def _clipped_mean(clip=2.0):
  """Returns a default MeanFactory wrapped in a ClippingFactory."""
  inner_mean = mean_factory.MeanFactory()
  return clipping_factory.ClippingFactory(clip, inner_mean)
コード例 #28
0
def _zeroed_mean(clip=2.0, norm_order=2.0):
    """Returns a default MeanFactory wrapped by robust_factory.zeroing_factory."""
    inner_mean = mean_factory.MeanFactory()
    return robust_factory.zeroing_factory(clip, inner_mean, norm_order)
コード例 #29
0
def _clipped_mean(clip=2.0):
    """Returns a default MeanFactory wrapped by robust_factory.clipping_factory."""
    inner_mean = mean_factory.MeanFactory()
    return robust_factory.clipping_factory(clip, inner_mean)
コード例 #30
0
 def test_incorrect_value_type_raises(self, bad_value_type):
     """Checks that `create` rejects an unsupported value type."""
     aggregator = mean_factory.MeanFactory()
     with self.assertRaises(TypeError):
         aggregator.create(bad_value_type)