def test_raise_on_invalid_clip_type(self, value):
  with self.assertRaises(TypeError):
    modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=value, clip_range_upper=2)
  with self.assertRaises(TypeError):
    modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2, clip_range_upper=value)

def test_type_properties_simple(self):
  value_type = computation_types.to_type((tf.int32, (2,)))
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  process = agg_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  server_state_type = computation_types.at_server(())

  expected_init_type = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(expected_init_type))

  expected_measurements_type = collections.OrderedDict(modclip=())
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

def test_clip_sum(self, clip_range_lower, clip_range_upper, client_data,
                  expected_sum):
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower, clip_range_upper)
  value_type = computation_types.to_type(tf.int32)
  process = agg_factory.create(value_type)
  state = process.initialize()
  output = process.next(state, client_data)
  self.assertEqual(output.result, expected_sum)

def test_clip_sum_struct(self, clip_range_lower, clip_range_upper,
                         client_data, expected_sum):
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower, clip_range_upper)
  value_type = computation_types.to_type((tf.int32, (2,)))
  process = agg_factory.create(value_type)
  state = process.initialize()
  client_tensor_data = [
      tf.constant(v, dtype=tf.int32, shape=(2,)) for v in client_data
  ]
  output = process.next(state, client_tensor_data)
  self.assertAllClose(
      tf.constant(expected_sum, dtype=tf.int32, shape=(2,)), output.result)

def test_component_tensor_dtypes_raise_on(self, value_type):
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  value_type = computation_types.to_type(value_type)
  with self.assertRaisesRegex(TypeError, 'must all be integers'):
    agg_factory.create(value_type)

def test_tff_value_types_raise_on(self, value_type):
  agg_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  value_type = computation_types.to_type(value_type)
  with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'):
    agg_factory.create(value_type)

def test_raise_on_clip_range(self, lower, upper):
  with self.assertRaises(ValueError):
    modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=lower, clip_range_upper=upper)

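
# Illustrative sketch (not the factory's actual implementation, which is not
# shown in this section): the tests above assume that modular clipping maps
# each value into the half-open range [clip_range_lower, clip_range_upper) by
# wrapping around the range width. A minimal standalone version of that map,
# for intuition only:
def _modular_clip(x: int, lower: int, upper: int) -> int:
  """Maps `x` into [lower, upper) by wrapping around the range width."""
  width = upper - lower
  return (x - lower) % width + lower

# For example, with the range [-2, 2) used by several tests above:
#   _modular_clip(2, -2, 2)  == -2  (wraps just past the upper bound)
#   _modular_clip(-3, -2, 2) == 1   (wraps just below the lower bound)
#   _modular_clip(1, -2, 2)  == 1   (in-range values are unchanged)
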
def create_hierarchical_histogram_aggregation_factory(
    num_bins: int,
    arity: int = 2,
    clip_mechanism: str = 'sub-sampling',
    max_records_per_user: int = 10,
    dp_mechanism: str = 'no-noise',
    noise_multiplier: float = 0.0,
    expected_clients_per_round: int = 10,
    bits: int = 22,
    enable_secure_sum: bool = True):
  """Creates a hierarchical histogram aggregation factory.

  The hierarchical histogram factory is constructed by composing 3 aggregation
  factories.
  (1) The inner-most factory is `SumFactory`.
  (2) The middle factory is `DifferentiallyPrivateFactory` whose inner query
      is `TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
      constructs the hierarchical histogram and checks the norm bound of the
      hierarchical histogram at clients, 2) adds noise either at clients or at
      the server according to `dp_mechanism`.
  (3) The outer-most factory is `HistogramClippingSumFactory`, which clips the
      input histogram to bound each user's contribution.

  Args:
    num_bins: An `int` representing the input histogram size.
    arity: An `int` representing the branching factor of the tree. Defaults to
      2.
    clip_mechanism: A `str` representing the clipping mechanism. Currently
      supported mechanisms are
      - 'sub-sampling': (Default) Uniformly samples up to
        `max_records_per_user` records without replacement from the client
        dataset.
      - 'distinct': Uniquifies the client dataset and uniformly samples up to
        `max_records_per_user` records without replacement from it.
    max_records_per_user: An `int` representing the maximum number of records
      each user can include in their local histogram. Defaults to 10.
    dp_mechanism: A `str` representing the differentially private mechanism to
      use. Currently supported mechanisms are
      - 'no-noise': (Default) Tree aggregation mechanism without noise.
      - 'central-gaussian': Tree aggregation with central Gaussian mechanism.
      - 'distributed-discrete-gaussian': Tree aggregation mechanism with the
        distributed discrete Gaussian mechanism in "The Distributed Discrete
        Gaussian Mechanism for Federated Learning with Secure Aggregation.
        Peter Kairouz, Ziyu Liu, Thomas Steinke".
    noise_multiplier: A `float` specifying the noise multiplier (central noise
      stddev / L2 clip norm) for model updates. Only needed when
      `dp_mechanism` is not 'no-noise'. Defaults to 0.0.
    expected_clients_per_round: An `int` specifying the lower bound on the
      expected number of clients. Only needed when `dp_mechanism` is
      'distributed-discrete-gaussian'. Defaults to 10.
    bits: A positive integer specifying the communication bit-width B (where
      2**B will be the field size for SecAgg operations). Only needed when
      `dp_mechanism` is 'distributed-discrete-gaussian'. Please read the
      precautions below carefully and set `bits` accordingly; otherwise,
      unexpected overflow or accuracy degradation might happen.
      (1) Should be in the inclusive range [1, 22] to avoid overflow inside
          secure aggregation;
      (2) Should be at least as large as
          `log2(4 * sqrt(expected_clients_per_round) * noise_multiplier *
          l2_norm_bound + expected_clients_per_round * max_records_per_user)
          + 1` to avoid accuracy degradation caused by frequent modular
          clipping;
      (3) If the number of clients exceeds `expected_clients_per_round`,
          overflow might happen.
    enable_secure_sum: Whether to aggregate clients' updates by secure sum or
      not. Defaults to `True`. When `dp_mechanism` is set to
      `'distributed-discrete-gaussian'`, `enable_secure_sum` must be `True`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    TypeError: If arguments have the wrong type(s).
    ValueError: If arguments have invalid value(s).
""" _check_positive(num_bins, 'num_bins') _check_greater_equal(arity, 2, 'arity') _check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS, 'clip_mechanism') _check_positive(max_records_per_user, 'max_records_per_user') _check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism') _check_non_negative(noise_multiplier, 'noise_multiplier') _check_positive(expected_clients_per_round, 'expected_clients_per_round') _check_in_range(bits, 'bits', 1, 22) # Converts `max_records_per_user` to the corresponding norm bound according to # the chosen `clip_mechanism` and `dp_mechanism`. if dp_mechanism in ['central-gaussian', 'distributed-discrete-gaussian']: if clip_mechanism == 'sub-sampling': l2_norm_bound = max_records_per_user * math.sqrt( _tree_depth(num_bins, arity)) elif clip_mechanism == 'distinct': # The following code block converts `max_records_per_user` to L2 norm # bound of the hierarchical histogram layer by layer. For the bottom # layer with only 0s and at most `max_records_per_user` 1s, the L2 norm # bound is `sqrt(max_records_per_user)`. For the second layer from bottom, # the worst case is only 0s and `max_records_per_user/2` 2s. And so on # until the root node. Another natural L2 norm bound on each layer is # `max_records_per_user` so we take the minimum between the two bounds. square_l2_norm_bound = 0. square_layer_l2_norm_bound = max_records_per_user for _ in range(_tree_depth(num_bins, arity)): square_l2_norm_bound += min(max_records_per_user**2, square_layer_l2_norm_bound) square_layer_l2_norm_bound *= arity l2_norm_bound = math.sqrt(square_l2_norm_bound) if not enable_secure_sum and dp_mechanism in DISTRIBUTED_DP_MECHANISMS: raise ValueError(f'When dp_mechanism is {DISTRIBUTED_DP_MECHANISMS}, ' 'enable_secure_sum must be set to True to preserve ' 'distributed DP.') # Build nested aggregtion factory from innermost to outermost. # 1. Sum factory. The most inner factory that sums the preprocessed records. # (1) If `enable_secure_sum` is `False`, should be `SumFactory`. if not enable_secure_sum: nested_factory = sum_factory.SumFactory() else: # (2) If `enable_secure_sum` is `True`, and `dp_mechanism` is 'no-noise' or # 'central-gaussian', the sum factory should be `SecureSumFactory`, with # a `upper_bound_threshold` of `max_records_per_user`. When `dp_mechanism` # is 'central-gaussian', use a float `SecureSumFactory` to be compatible # with `GaussianSumQuery`. if dp_mechanism in ['no-noise']: nested_factory = secure.SecureSumFactory(max_records_per_user) elif dp_mechanism in ['central-gaussian']: nested_factory = secure.SecureSumFactory( float(max_records_per_user)) # (3) If `dp_mechanism` is in `DISTRIBUTED_DP_MECHANISMS`, should be # `SecureSumFactory`. To preserve DP and avoid overflow, we have 4 # modular clips from nesting two modular clip aggregators: # #1. outer-client: clips to [-2**(bits-1), 2**(bits-1)) # Bounds the client values. # #2. inner-client: clips to [0, 2**bits) # Similar to applying a two's complement to the values such that # frequent values (post-rotation) are now near 0 (representing small # positives) and 2**bits (small negatives). 0 also always map to 0, # and we do not require another explicit value range shift from # [-2**(bits-1), 2**(bits-1)] to [0, 2**bits] to make sure that # values are compatible with SecAgg's mod m = 2**bits. This can be # reverted at #4. # #3. inner-server: clips to [0, 2**bits) # Ensures the aggregated value range does not grow by # `log2(expected_clients_per_round)`. 
    #       NOTE: If the underlying SecAgg is implemented using the new
    #       `tff.federated_secure_modular_sum()` operator with the same
    #       modular clipping range, then this would correspond to a no-op.
    #   #4. outer-server: clips to [-2**(bits-1), 2**(bits-1))
    #       Keeps aggregated values centered near 0 out of the logical SecAgg
    #       black box for outer aggregators.
    elif dp_mechanism in ['distributed-discrete-gaussian']:
      # TODO(b/196312838): Please add scaling to the distributed case once we
      # have a stable guideline for setting the scaling factor to improve
      # performance and avoid overflow. The test below makes sure that
      # modular clipping happens with small probability so the accuracy of
      # the result won't be harmed. However, if the number of clients exceeds
      # `expected_clients_per_round`, overflow still might happen. It is the
      # caller's responsibility to carefully choose `bits` according to
      # system details to avoid overflow or performance degradation.
      if bits < math.log2(4 * math.sqrt(expected_clients_per_round) *
                          noise_multiplier * l2_norm_bound +
                          expected_clients_per_round *
                          max_records_per_user) + 1:
        raise ValueError(
            f'The selected bit-width ({bits}) is too small for the '
            f'given parameters (expected_clients_per_round = '
            f'{expected_clients_per_round}, max_records_per_user = '
            f'{max_records_per_user}, noise_multiplier = '
            f'{noise_multiplier}) and will harm the accuracy of the '
            f'result. Please decrease the '
            f'`expected_clients_per_round` / `max_records_per_user` '
            f'/ `noise_multiplier`, or increase `bits`.')
      nested_factory = secure.SecureSumFactory(
          upper_bound_threshold=2**bits - 1, lower_bound_threshold=0)
      nested_factory = modular_clipping_factory.ModularClippingSumFactory(
          clip_range_lower=0,
          clip_range_upper=2**bits,
          inner_agg_factory=nested_factory)
      nested_factory = modular_clipping_factory.ModularClippingSumFactory(
          clip_range_lower=-2**(bits - 1),
          clip_range_upper=2**(bits - 1),
          inner_agg_factory=nested_factory)

  # 2. DP operations.
  # Constructs `DifferentiallyPrivateFactory` according to the chosen
  # `dp_mechanism`.
  if dp_mechanism == 'central-gaussian':
    query = tfp.TreeRangeSumQuery.build_central_gaussian_query(
        l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
    # If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`,
    # then the record is cast to `tf.float32` before being fed to the DP
    # factory.
    cast_to_float = True
  elif dp_mechanism == 'distributed-discrete-gaussian':
    query = tfp.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
        l2_norm_bound,
        noise_multiplier * l2_norm_bound /
        math.sqrt(expected_clients_per_round), arity)
    # If the inner `DifferentiallyPrivateFactory` uses
    # `DistributedDiscreteGaussianQuery`, then the record is kept as
    # `tf.int32` before being fed to the DP factory.
    cast_to_float = False
  elif dp_mechanism == 'no-noise':
    inner_query = tfp.NoPrivacySumQuery()
    query = tfp.TreeRangeSumQuery(arity=arity, inner_query=inner_query)
    # If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
    # the record is kept as `tf.int32` before being fed to the DP factory.
    cast_to_float = False
  else:
    raise ValueError('Unexpected dp_mechanism.')
  nested_factory = differential_privacy.DifferentiallyPrivateFactory(
      query, nested_factory)

  # 3. Clip as specified by `clip_mechanism`.
  nested_factory = clipping_factory.HistogramClippingSumFactory(
      clip_mechanism=clip_mechanism,
      max_records_per_user=max_records_per_user,
      inner_agg_factory=nested_factory,
      cast_to_float=cast_to_float)

  return nested_factory
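
# Illustrative usage (a sketch, not part of this module's API surface): how a
# caller might compose the factory above and instantiate an aggregation
# process for a histogram-shaped value type. The `num_bins=256` /
# `noise_multiplier=1.0` values and the `client_histograms` placeholder are
# hypothetical; `computation_types` and `tf` are the same symbols this module
# already imports.
#
#   agg_factory = create_hierarchical_histogram_aggregation_factory(
#       num_bins=256,
#       arity=2,
#       clip_mechanism='sub-sampling',
#       max_records_per_user=10,
#       dp_mechanism='central-gaussian',
#       noise_multiplier=1.0)
#   value_type = computation_types.to_type((tf.int32, (256,)))
#   process = agg_factory.create(value_type)
#   state = process.initialize()
#   # output = process.next(state, client_histograms)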