Пример #1
0
 def test_raise_on_invalid_clip_type(self, value):
   """An invalid `value` for either clip bound makes the factory raise TypeError."""
   factory_cls = modular_clipping_factory.ModularClippingSumFactory
   # Exercise `value` as the lower bound and then as the upper bound, each
   # paired with a valid partner bound.
   bad_bound_configs = (
       dict(clip_range_lower=value, clip_range_upper=2),
       dict(clip_range_lower=-2, clip_range_upper=value),
   )
   for bounds in bad_bound_configs:
     with self.assertRaises(TypeError):
       factory_cls(**bounds)
Пример #2
0
  def test_type_properties_simple(self):
    """Checks the initialize/next type signatures for an int32[2] value."""
    value_type = computation_types.to_type((tf.int32, (2,)))
    process = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2, clip_range_upper=2).create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # The inner SumFactory carries no state, so server state is an empty struct.
    state_type = computation_types.at_server(())

    init_type = computation_types.FunctionType(
        parameter=None, result=state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(init_type))

    measurements_type = computation_types.at_server(
        collections.OrderedDict(modclip=()))
    next_params = collections.OrderedDict(
        state=state_type,
        value=computation_types.at_clients(value_type))
    next_output = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(value_type),
        measurements=measurements_type)
    next_type = computation_types.FunctionType(
        parameter=next_params, result=next_output)
    self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
Пример #3
0
 def test_clip_sum(self, clip_range_lower, clip_range_upper, client_data,
                   expected_sum):
   """Runs one round over scalar int32 client values and checks the sum."""
   process = modular_clipping_factory.ModularClippingSumFactory(
       clip_range_lower, clip_range_upper).create(
           computation_types.to_type(tf.int32))
   initial_state = process.initialize()
   round_output = process.next(initial_state, client_data)
   self.assertEqual(round_output.result, expected_sum)
Пример #4
0
 def test_clip_sum_struct(self, clip_range_lower, clip_range_upper,
                          client_data, expected_sum):
   """Runs one round over shape-(2,) int32 client tensors and checks the sum."""
   agg_factory = modular_clipping_factory.ModularClippingSumFactory(
       clip_range_lower, clip_range_upper)
   value_type = computation_types.to_type((tf.int32, (2,)))
   process = agg_factory.create(value_type)
   state = process.initialize()
   # Each client entry `v` is materialized as a shape-(2,) int32 tensor;
   # presumably `v` is a scalar filled into both entries (confirm against the
   # test parameterization).
   client_tensor_data = [
       tf.constant(v, dtype=tf.int32, shape=(2,)) for v in client_data
   ]
   output = process.next(state, client_tensor_data)
   # The aggregate is an exact integer, so compare exactly: `assertAllClose`
   # uses a relative tolerance that can hide small off-by-k errors at the large
   # magnitudes modular clipping operates on.
   self.assertAllEqual(
       tf.constant(expected_sum, dtype=tf.int32, shape=(2,)), output.result)
Пример #5
0
 def test_component_tensor_dtypes_raise_on(self, value_type):
   """`create` rejects value types whose tensors are not all integers."""
   factory = modular_clipping_factory.ModularClippingSumFactory(
       clip_range_lower=-2, clip_range_upper=2)
   tff_value_type = computation_types.to_type(value_type)
   with self.assertRaisesRegex(TypeError, 'must all be integers'):
     factory.create(tff_value_type)
Пример #6
0
 def test_tff_value_types_raise_on(self, value_type):
   """`create` rejects value types of an unsupported TFF type kind."""
   factory = modular_clipping_factory.ModularClippingSumFactory(
       clip_range_lower=-2, clip_range_upper=2)
   tff_value_type = computation_types.to_type(value_type)
   with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'):
     factory.create(tff_value_type)
Пример #7
0
 def test_raise_on_clip_range(self, lower, upper):
   """An invalid (lower, upper) clip range makes the factory raise ValueError."""
   bounds = dict(clip_range_lower=lower, clip_range_upper=upper)
   with self.assertRaises(ValueError):
     modular_clipping_factory.ModularClippingSumFactory(**bounds)
Пример #8
0
def create_hierarchical_histogram_aggregation_factory(
        num_bins: int,
        arity: int = 2,
        clip_mechanism: str = 'sub-sampling',
        max_records_per_user: int = 10,
        dp_mechanism: str = 'no-noise',
        noise_multiplier: float = 0.0,
        expected_clients_per_round: int = 10,
        bits: int = 22,
        enable_secure_sum: bool = True):
    """Creates hierarchical histogram aggregation factory.

    Hierarchical histogram factory is constructed by composing 3 aggregation
    factories.
    (1) The inner-most factory sums the preprocessed records. It is
        `SumFactory` when `enable_secure_sum` is `False`; otherwise it is a
        `SecureSumFactory`, wrapped in two `ModularClippingSumFactory` layers
        when `dp_mechanism` is 'distributed-discrete-gaussian'.
    (2) The middle factory is `DifferentiallyPrivateFactory` whose query is
        `TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
        constructs the hierarchical histogram and checks the norm bound of the
        hierarchical histogram at clients, 2) adds noise either at clients or
        at server according to `dp_mechanism`.
    (3) The outer-most factory is `HistogramClippingSumFactory` which clips the
        input histogram to bound each user's contribution.

    Args:
      num_bins: An `int` representing the input histogram size.
      arity: An `int` representing the branching factor of the tree. Defaults
        to 2.
      clip_mechanism: A `str` representing the clipping mechanism. Currently
        supported mechanisms are
        - 'sub-sampling': (Default) Uniformly sample up to
          `max_records_per_user` records without replacement from the client
          dataset.
        - 'distinct': Uniquify client dataset and uniformly sample up to
          `max_records_per_user` records without replacement from it.
      max_records_per_user: An `int` representing the maximum of records each
        user can include in their local histogram. Defaults to 10.
      dp_mechanism: A `str` representing the differentially private mechanism
        to use. Currently supported mechanisms are
        - 'no-noise': (Default) Tree aggregation mechanism without noise.
        - 'central-gaussian': Tree aggregation with central Gaussian mechanism.
        - 'distributed-discrete-gaussian': Tree aggregation mechanism with
          distributed discrete Gaussian mechanism in "The Distributed Discrete
          Gaussian Mechanism for Federated Learning with Secure Aggregation.
          Peter Kairouz, Ziyu Liu, Thomas Steinke".
      noise_multiplier: A `float` specifying the noise multiplier (central noise
        stddev / L2 norm bound of the hierarchical histogram). Only needed when
        `dp_mechanism` is not 'no-noise'. Defaults to 0.0.
      expected_clients_per_round: An `int` specifying the lower bound of the
        expected number of clients. Only needed when `dp_mechanism` is
        'distributed-discrete-gaussian'. Defaults to 10.
      bits: A positive integer specifying the communication bit-width B (where
        2**B will be the field size for SecAgg operations). Only needed when
        `dp_mechanism` is 'distributed-discrete-gaussian'. Please read the below
        precautions carefully and set `bits` accordingly. Otherwise, unexpected
        overflow or accuracy degradation might happen. (1) Should be in the
        inclusive range [1, 22] to avoid overflow inside secure aggregation;
        (2) Should be at least as large as
        `log2(4 * sqrt(expected_clients_per_round) * noise_multiplier *
        l2_norm_bound + expected_clients_per_round * max_records_per_user) + 1`
        to avoid accuracy degradation caused by frequent modular clipping;
        (3) If the number of clients exceed `expected_clients_per_round`,
        overflow might happen.
      enable_secure_sum: Whether to aggregate client's update by secure sum or
        not. Defaults to `True`. When `dp_mechanism` is set to
        `'distributed-discrete-gaussian'`, `enable_secure_sum` must be `True`.

    Returns:
      `tff.aggregators.UnweightedAggregationFactory`.

    Raises:
      TypeError: If arguments have the wrong type(s).
      ValueError: If arguments have invalid value(s), if a distributed DP
        mechanism is requested without secure sum, if `bits` is too small for
        the given parameters, or if `clip_mechanism` has no L2 norm bound
        conversion for a noisy `dp_mechanism`.
    """
    _check_positive(num_bins, 'num_bins')
    _check_greater_equal(arity, 2, 'arity')
    _check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
                      'clip_mechanism')
    _check_positive(max_records_per_user, 'max_records_per_user')
    _check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
    _check_non_negative(noise_multiplier, 'noise_multiplier')
    _check_positive(expected_clients_per_round, 'expected_clients_per_round')
    _check_in_range(bits, 'bits', 1, 22)

    # Distributed DP relies on secure aggregation; without it the server would
    # see individual (insufficiently noised) client contributions.
    if not enable_secure_sum and dp_mechanism in DISTRIBUTED_DP_MECHANISMS:
        raise ValueError(f'When dp_mechanism is {DISTRIBUTED_DP_MECHANISMS}, '
                         'enable_secure_sum must be set to True to preserve '
                         'distributed DP.')

    # Converts `max_records_per_user` to the corresponding norm bound according
    # to the chosen `clip_mechanism` and `dp_mechanism`. Only the noisy
    # mechanisms need (and therefore define) `l2_norm_bound`.
    if dp_mechanism in ['central-gaussian', 'distributed-discrete-gaussian']:
        if clip_mechanism == 'sub-sampling':
            # Each tree layer's L2 norm is at most `max_records_per_user` (its
            # L1 norm); summing squares over all layers gives this bound.
            l2_norm_bound = max_records_per_user * math.sqrt(
                _tree_depth(num_bins, arity))
        elif clip_mechanism == 'distinct':
            # The following code block converts `max_records_per_user` to L2
            # norm bound of the hierarchical histogram layer by layer. For the
            # bottom layer with only 0s and at most `max_records_per_user` 1s,
            # the L2 norm bound is `sqrt(max_records_per_user)`. For the second
            # layer from bottom, the worst case is only 0s and
            # `max_records_per_user/2` 2s. And so on until the root node.
            # Another natural L2 norm bound on each layer is
            # `max_records_per_user` so we take the minimum between the two.
            square_l2_norm_bound = 0.
            square_layer_l2_norm_bound = max_records_per_user
            for _ in range(_tree_depth(num_bins, arity)):
                square_l2_norm_bound += min(max_records_per_user**2,
                                            square_layer_l2_norm_bound)
                square_layer_l2_norm_bound *= arity
            l2_norm_bound = math.sqrt(square_l2_norm_bound)
        else:
            # Without this, `l2_norm_bound` would be unbound and fail later
            # with an opaque UnboundLocalError.
            raise ValueError(
                f'Unsupported clip_mechanism {clip_mechanism!r} for '
                f'dp_mechanism {dp_mechanism!r}.')

    # Build nested aggregation factory from innermost to outermost.
    # 1. Sum factory. The most inner factory that sums the preprocessed records.
    # (1) If `enable_secure_sum` is `False`, should be `SumFactory`.
    if not enable_secure_sum:
        nested_factory = sum_factory.SumFactory()
    else:
        # (2) If `enable_secure_sum` is `True`, and `dp_mechanism` is 'no-noise'
        # or 'central-gaussian', the sum factory should be `SecureSumFactory`,
        # with an `upper_bound_threshold` of `max_records_per_user`. When
        # `dp_mechanism` is 'central-gaussian', use a float `SecureSumFactory`
        # to be compatible with `GaussianSumQuery`.
        if dp_mechanism in ['no-noise']:
            nested_factory = secure.SecureSumFactory(max_records_per_user)
        elif dp_mechanism in ['central-gaussian']:
            nested_factory = secure.SecureSumFactory(
                float(max_records_per_user))
        # (3) If `dp_mechanism` is in `DISTRIBUTED_DP_MECHANISMS`, should be
        #     `SecureSumFactory`. To preserve DP and avoid overflow, we have 4
        #     modular clips from nesting two modular clip aggregators:
        #     #1. outer-client: clips to [-2**(bits-1), 2**(bits-1))
        #         Bounds the client values.
        #     #2. inner-client: clips to [0, 2**bits)
        #         Similar to applying a two's complement to the values such
        #         that frequent values (post-rotation) are now near 0
        #         (representing small positives) and 2**bits (small
        #         negatives). 0 also always map to 0, and we do not require
        #         another explicit value range shift from
        #         [-2**(bits-1), 2**(bits-1)] to [0, 2**bits] to make sure that
        #         values are compatible with SecAgg's mod m = 2**bits. This can
        #         be reverted at #4.
        #     #3. inner-server: clips to [0, 2**bits)
        #         Ensures the aggregated value range does not grow by
        #         `log2(expected_clients_per_round)`.
        #         NOTE: If underlying SecAgg is implemented using the new
        #         `tff.federated_secure_modular_sum()` operator with the same
        #         modular clipping range, then this would correspond to a no-op.
        #     #4. outer-server: clips to [-2**(bits-1), 2**(bits-1))
        #         Keeps aggregated values centered near 0 out of the logical
        #         SecAgg black box for outer aggregators.
        elif dp_mechanism in ['distributed-discrete-gaussian']:
            # TODO(b/196312838): Please add scaling to the distributed case once
            # we have a stable guideline for setting scaling factor to improve
            # performance and avoid overflow. The below test is to make sure
            # that modular clipping happens with small probability so the
            # accuracy of the result won't be harmed. However, if the number of
            # clients exceeds `expected_clients_per_round`, overflow still might
            # happen. It is the caller's responsibility to carefully choose
            # `bits` according to system details to avoid overflow or
            # performance degradation.
            if bits < math.log2(4 * math.sqrt(expected_clients_per_round) *
                                noise_multiplier * l2_norm_bound +
                                expected_clients_per_round *
                                max_records_per_user) + 1:
                raise ValueError(
                    f'The selected bit-width ({bits}) is too small for the '
                    f'given parameters (expected_clients_per_round = '
                    f'{expected_clients_per_round}, max_records_per_user = '
                    f'{max_records_per_user}, noise_multiplier = '
                    f'{noise_multiplier}) and will harm the accuracy of the '
                    f'result. Please decrease the '
                    f'`expected_clients_per_round` / `max_records_per_user` '
                    f'/ `noise_multiplier`, or increase `bits`.')
            nested_factory = secure.SecureSumFactory(
                upper_bound_threshold=2**bits - 1, lower_bound_threshold=0)
            nested_factory = modular_clipping_factory.ModularClippingSumFactory(
                clip_range_lower=0,
                clip_range_upper=2**bits,
                inner_agg_factory=nested_factory)
            nested_factory = modular_clipping_factory.ModularClippingSumFactory(
                clip_range_lower=-2**(bits - 1),
                clip_range_upper=2**(bits - 1),
                inner_agg_factory=nested_factory)

    # 2. DP operations.
    # Constructs `DifferentiallyPrivateFactory` according to the chosen
    # `dp_mechanism`.
    if dp_mechanism == 'central-gaussian':
        query = tfp.TreeRangeSumQuery.build_central_gaussian_query(
            l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
        # If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`,
        # then the record is casted to `tf.float32` before feeding to the DP
        # factory.
        cast_to_float = True
    elif dp_mechanism == 'distributed-discrete-gaussian':
        # The central noise stddev is split across the expected clients, each
        # adding a 1/sqrt(expected_clients_per_round) share.
        query = tfp.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
            l2_norm_bound, noise_multiplier * l2_norm_bound /
            math.sqrt(expected_clients_per_round), arity)
        # If the inner `DifferentiallyPrivateFactory` uses
        # `DistributedDiscreteGaussianQuery`, then the record is kept as
        # `tf.int32` before feeding to the DP factory.
        cast_to_float = False
    elif dp_mechanism == 'no-noise':
        inner_query = tfp.NoPrivacySumQuery()
        query = tfp.TreeRangeSumQuery(arity=arity, inner_query=inner_query)
        # If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`,
        # then the record is kept as `tf.int32` before feeding to the DP
        # factory.
        cast_to_float = False
    else:
        raise ValueError('Unexpected dp_mechanism.')
    nested_factory = differential_privacy.DifferentiallyPrivateFactory(
        query, nested_factory)

    # 3. Clip as specified by `clip_mechanism`.
    nested_factory = clipping_factory.HistogramClippingSumFactory(
        clip_mechanism=clip_mechanism,
        max_records_per_user=max_records_per_user,
        inner_agg_factory=nested_factory,
        cast_to_float=cast_to_float)

    return nested_factory