def test_iterative_process_fails_with_dp_agg_and_none_client_weighting(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # No values should be clipped, but using inf directly zeroes out all
  # updates. Prefer a very large value that can still be handled in
  # multiplication/division.
  gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
  dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=gaussian_sum_query,
      record_aggregation_factory=sum_factory.SumFactory())
  dp_mean_factory = _DPMean(dp_sum_factory)

  with self.assertRaisesRegex(ValueError, 'unweighted aggregator'):
    training_process.build_training_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        server_optimizer_fn=_get_keras_optimizer_fn(0.01),
        client_optimizer_fn=_get_keras_optimizer_fn(0.001),
        reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
        aggregation_factory=dp_mean_factory,
        client_weighting=None,
        dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
def test_simple_sum(self):
  factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
  value_type = computation_types.to_type(tf.float32)
  process = factory_.create(value_type)

  # The test query has clip 1.0 and no noise, so this computes a clipped sum.
  state = process.initialize()
  client_data = [0.5, 1.0, 1.5]
  output = process.next(state, client_data)
  self.assertAllClose(2.5, output.result)
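# Hedged sketch: the tests here assume `_test_dp_query` clips each record to
# L2 norm 1.0 and adds no noise. An equivalent query (a hypothetical stand-in,
# not the module's actual helper, which may also expose `derive_metrics` used
# by the type-properties test below) could be:
_example_dp_query = tfp.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)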
def create_central_hierarchical_histogram_factory(
    stddev: float = 0.0,
    arity: int = 2,
    max_records_per_user: int = 10,
    secure_sum: bool = False):
  """Creates an aggregator for hierarchical histograms with differential privacy.

  Args:
    stddev: The standard deviation of the noise added to each node of the
      central tree.
    arity: The branching factor of the tree.
    max_records_per_user: The maximum number of records each user can upload
      in their local histogram.
    secure_sum: A boolean deciding whether to use secure aggregation. Defaults
      to `False`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    ValueError: If `stddev < 0`, `arity < 2`, or `max_records_per_user < 1`.
  """
  if stddev < 0:
    raise ValueError(f"Standard deviation should be nonnegative. "
                     f"stddev={stddev} is given.")
  if arity < 2:
    raise ValueError(f"Arity should be at least 2. "
                     f"arity={arity} is given.")
  if max_records_per_user < 1:
    raise ValueError(
        f"Maximum records per user should be at least 1. "
        f"max_records_per_user={max_records_per_user} is given.")

  central_tree_agg_query = tfp.privacy.dp_query.tree_aggregation_query.CentralTreeSumQuery(
      stddev=stddev, arity=arity, l1_bound=max_records_per_user)

  if secure_sum:
    inner_agg_factory = secure.SecureSumFactory(
        upper_bound_threshold=float(max_records_per_user),
        lower_bound_threshold=0.)
  else:
    inner_agg_factory = sum_factory.SumFactory()

  return differential_privacy.DifferentiallyPrivateFactory(
      central_tree_agg_query, inner_agg_factory)
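# A minimal usage sketch (not from the library): drives one aggregation round
# with the factory above. The 4-bin float32 histogram type and the client
# values are illustrative assumptions; the value type must match what
# `CentralTreeSumQuery` expects in your deployment.
factory_ = create_central_hierarchical_histogram_factory(
    stddev=0.0, arity=2, max_records_per_user=10)
value_type = computation_types.to_type((tf.float32, (4,)))
process = factory_.create(value_type)

state = process.initialize()
client_histograms = [[1., 0., 2., 1.], [0., 3., 1., 0.]]
output = process.next(state, client_histograms)
# With stddev=0.0 the result is the exact tree-aggregated sum of the client
# histograms.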
def test_inner_sum(self):
  factory_ = differential_privacy.DifferentiallyPrivateFactory(
      _test_dp_query, _test_inner_agg_factory)
  value_type = computation_types.to_type(tf.float32)
  process = factory_.create(value_type)

  # The test query has clip 1.0 and no noise, so this computes a clipped sum.
  # The inner aggregator adds another 1.0 (post-clipping).
  state = process.initialize()
  self.assertAllEqual(0, state[1])

  client_data = [0.5, 1.0, 1.5]
  output = process.next(state, client_data)
  self.assertAllEqual(1, output.state[1])
  self.assertAllClose(3.5, output.result)
  self.assertAllEqual(test_utils.MEASUREMENT_CONSTANT,
                      output.measurements['dp'])
def test_structure_sum(self):
  factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
  value_type = computation_types.to_type([tf.float32, tf.float32])
  process = factory_.create(value_type)

  # The test query has clip 1.0 and no noise, so this computes a clipped sum.
  state = process.initialize()
  # pyformat: disable
  client_data = [
      [0.1, 0.2],         # not clipped (norm < 1)
      [5 / 13, 12 / 13],  # not clipped (norm == 1)
      [3.0, 4.0]          # clipped to [3 / 5, 4 / 5]
  ]
  output = process.next(state, client_data)
  expected_result = [0.1 + 5 / 13 + 3 / 5,
                     0.2 + 12 / 13 + 4 / 5]
  # pyformat: enable
  self.assertAllClose(expected_result, output.result)
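# A hedged NumPy sketch (not part of the test) of the L2 clipping arithmetic
# the expected values above rely on: a vector is rescaled to the clip norm
# only when its L2 norm exceeds the clip.
import numpy as np

def l2_clip(v, clip=1.0):
  norm = np.linalg.norm(v)
  return v if norm <= clip else v * (clip / norm)

print(l2_clip(np.array([0.1, 0.2])))         # unchanged: norm ~= 0.224 < 1
print(l2_clip(np.array([5 / 13, 12 / 13])))  # unchanged: norm == 1 exactly
print(l2_clip(np.array([3.0, 4.0])))         # norm 5 -> clipped to [0.6, 0.8]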
def test_type_properties(self, value_type, inner_agg_factory):
  factory_ = differential_privacy.DifferentiallyPrivateFactory(
      _test_dp_query, inner_agg_factory)
  self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = factory_.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  query_state = _test_dp_query.initial_global_state()
  query_state_type = type_conversions.type_from_tensors(query_state)
  query_metrics_type = type_conversions.type_from_tensors(
      _test_dp_query.derive_metrics(query_state))

  inner_state_type = tf.int32 if inner_agg_factory else ()
  server_state_type = computation_types.at_server(
      (query_state_type, inner_state_type))
  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  inner_measurements_type = tf.int32 if inner_agg_factory else ()
  expected_measurements_type = computation_types.at_server(
      collections.OrderedDict(
          dp_query_metrics=query_metrics_type, dp=inner_measurements_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=server_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
def test_adaptive_query(self):
  query = tfp.QuantileAdaptiveClipSumQuery(
      initial_l2_norm_clip=1.0,
      noise_multiplier=0.0,
      target_unclipped_quantile=1.0,
      learning_rate=1.0,
      clipped_count_stddev=0.0,
      expected_num_records=3.0,
      geometric_update=False)
  factory_ = differential_privacy.DifferentiallyPrivateFactory(query)
  value_type = computation_types.to_type(tf.float32)
  process = factory_.create(value_type)

  state = process.initialize()
  client_data = [0.5, 1.5, 2.0]  # Two values are clipped on the first round.
  expected_result = 0.5 + 1.0 + 1.0
  output = process.next(state, client_data)
  self.assertAllClose(expected_result, output.result)

  # The clip is increased by 2/3 to 5/3.
  expected_result = 0.5 + 1.5 + 5 / 3
  output = process.next(output.state, client_data)
  self.assertAllClose(expected_result, output.result)
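# A hedged arithmetic check (not part of the test) for the clip update above.
# With geometric_update=False, the quantile estimator moves the clip linearly:
#   new_clip = clip + learning_rate * (target_unclipped_quantile
#                                      - unclipped_fraction)
# On the first round only 0.5 of [0.5, 1.5, 2.0] lies within the clip of 1.0.
clip, learning_rate, target = 1.0, 1.0, 1.0
unclipped_fraction = 1 / 3
new_clip = clip + learning_rate * (target - unclipped_fraction)
assert abs(new_clip - 5 / 3) < 1e-9  # matches the 5/3 used by the test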
def test_incorrect_value_type_raises(self, bad_value_type):
  factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query)
  with self.assertRaises(TypeError):
    factory_.create(bad_value_type)
def test_init_non_agg_factory_raises(self):
  with self.assertRaises(TypeError):
    differential_privacy.DifferentiallyPrivateFactory(_test_dp_query,
                                                      'not an agg factory')
def test_type_properties(self, value_type, mechanism):
  ddp_factory = _make_test_factory(mechanism=mechanism)
  self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = ddp_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The state is a nested object with component factory states. Construct
  # test factories directly and compare the signatures.
  modsum_f = secure.SecureModularSumFactory(2**15, True)

  if mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=10.0, local_stddev=10.0)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=10.0, l2_norm_bound=10.0, local_stddev=10.0)

  dp_f = differential_privacy.DifferentiallyPrivateFactory(dp_query, modsum_f)
  discrete_f = discretization.DiscretizationFactory(dp_f)
  l2clip_f = robust.clipping_factory(
      clipping_norm=10.0, inner_agg_factory=discrete_f)
  rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
  expected_process = concat.concat_factory(rot_f).create(value_type)

  # Check init_fn/state.
  expected_init_type = expected_process.initialize.type_signature
  expected_state_type = expected_init_type.result
  actual_init_type = process.initialize.type_signature
  self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

  # Check next_fn/measurements.
  tensor2type = type_conversions.type_from_tensors
  discrete_state = discrete_f.create(
      computation_types.to_type(tf.float32)).initialize()
  dp_query_state = dp_query.initial_global_state()
  dp_query_metrics_type = tensor2type(dp_query.derive_metrics(dp_query_state))
  expected_measurements_type = collections.OrderedDict(
      l2_clip=robust.NORM_TF_TYPE,
      scale_factor=tensor2type(discrete_state['scale_factor']),
      scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
      scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
      actual_num_clients=tf.int32,
      padded_dim=tf.int32,
      dp_query_metrics=dp_query_metrics_type)
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  actual_next_type = process.next.type_signature
  self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))

  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except:  # pylint: disable=bare-except
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_execution_with_custom_dp_query(self):
  client_data = create_emnist_client_data()
  train_data = [client_data(), client_data()]

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # No values should be clipped, but using inf directly zeroes out all
  # updates. Prefer a very large value that can still be handled in
  # multiplication/division.
  gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
  dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=gaussian_sum_query,
      record_aggregation_factory=sum_factory.SumFactory())
  dp_mean_factory = _DPMean(dp_sum_factory)

  # Disable reconstruction via a 0 learning rate to ensure the post-recon
  # loss matches exact expectations in round 0 and decreases by the next
  # round.
  trainer = training_process.build_training_process(
      MnistModel,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      server_optimizer_fn=_get_keras_optimizer_fn(0.01),
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
      aggregation_factory=dp_mean_factory,
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn,
      client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
  )
  state = trainer.initialize()

  outputs = []
  states = []
  for _ in range(2):
    state, output = trainer.next(state, train_data)
    outputs.append(output)
    states.append(state)

  # All weights and biases are initialized to 0, so initial logits are all 0
  # and softmax probabilities are uniform over the 10 classes, giving a
  # negative log likelihood of -ln(1/10). This holds only in expectation, so
  # the tolerance is increased.
  self.assertAllClose(
      outputs[0]['train']['loss'], tf.math.log(10.0), rtol=1e-4)
  self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
  self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

  # Expect 6 reconstruction examples and 6 training examples. Only training
  # examples are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_examples_total'], 6.0)
  self.assertEqual(outputs[1]['train']['num_examples_total'], 6.0)

  # Expect 4 reconstruction batches and 4 training batches. Only training
  # batches are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_batches_total'], 4.0)
  self.assertEqual(outputs[1]['train']['num_batches_total'], 4.0)
def create_hierarchical_histogram_aggregation_factory(
    num_bins: int,
    arity: int = 2,
    clip_mechanism: str = 'sub-sampling',
    max_records_per_user: int = 10,
    dp_mechanism: str = 'no-noise',
    noise_multiplier: float = 0.0,
    expected_clients_per_round: int = 10,
    bits: int = 22,
    enable_secure_sum: bool = True):
  """Creates a hierarchical histogram aggregation factory.

  The hierarchical histogram factory is constructed by composing 3 aggregation
  factories.
  (1) The innermost factory is `SumFactory`.
  (2) The middle factory is `DifferentiallyPrivateFactory` whose inner query
      is `TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
      constructs the hierarchical histogram and checks the norm bound of the
      hierarchical histogram at clients, 2) adds noise either at clients or at
      the server according to `dp_mechanism`.
  (3) The outermost factory is `HistogramClippingSumFactory`, which clips the
      input histogram to bound each user's contribution.

  Args:
    num_bins: An `int` representing the input histogram size.
    arity: An `int` representing the branching factor of the tree. Defaults
      to 2.
    clip_mechanism: A `str` representing the clipping mechanism. Currently
      supported mechanisms are
      - 'sub-sampling': (Default) Uniformly sample up to
        `max_records_per_user` records without replacement from the client
        dataset.
      - 'distinct': Uniquify the client dataset and uniformly sample up to
        `max_records_per_user` records without replacement from it.
    max_records_per_user: An `int` representing the maximum number of records
      each user can include in their local histogram. Defaults to 10.
    dp_mechanism: A `str` representing the differentially private mechanism
      to use. Currently supported mechanisms are
      - 'no-noise': (Default) Tree aggregation mechanism without noise.
      - 'central-gaussian': Tree aggregation with central Gaussian mechanism.
      - 'distributed-discrete-gaussian': Tree aggregation mechanism with the
        distributed discrete Gaussian mechanism in "The Distributed Discrete
        Gaussian Mechanism for Federated Learning with Secure Aggregation.
        Peter Kairouz, Ziyu Liu, Thomas Steinke".
    noise_multiplier: A `float` specifying the noise multiplier (central noise
      stddev / L2 clip norm) for model updates. Only needed when
      `dp_mechanism` is not 'no-noise'. Defaults to 0.0.
    expected_clients_per_round: An `int` specifying the lower bound on the
      expected number of clients. Only needed when `dp_mechanism` is
      'distributed-discrete-gaussian'. Defaults to 10.
    bits: A positive integer specifying the communication bit-width B (where
      2**B will be the field size for SecAgg operations). Only needed when
      `dp_mechanism` is 'distributed-discrete-gaussian'. Please read the
      precautions below carefully and set `bits` accordingly; otherwise,
      unexpected overflow or accuracy degradation might happen.
      (1) Should be in the inclusive range [1, 22] to avoid overflow inside
          secure aggregation;
      (2) Should be at least as large as
          `log2(4 * sqrt(expected_clients_per_round) * noise_multiplier *
          l2_norm_bound + expected_clients_per_round * max_records_per_user)
          + 1` to avoid accuracy degradation caused by frequent modular
          clipping;
      (3) If the number of clients exceeds `expected_clients_per_round`,
          overflow might happen.
    enable_secure_sum: Whether to aggregate client updates by secure sum or
      not. Defaults to `True`. When `dp_mechanism` is set to
      `'distributed-discrete-gaussian'`, `enable_secure_sum` must be `True`.

  Returns:
    A `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    TypeError: If arguments have the wrong type(s).
    ValueError: If arguments have invalid value(s).
  """
  _check_positive(num_bins, 'num_bins')
  _check_greater_equal(arity, 2, 'arity')
  _check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
                    'clip_mechanism')
  _check_positive(max_records_per_user, 'max_records_per_user')
  _check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
  _check_non_negative(noise_multiplier, 'noise_multiplier')
  _check_positive(expected_clients_per_round, 'expected_clients_per_round')
  _check_in_range(bits, 'bits', 1, 22)

  # Convert `max_records_per_user` to the corresponding norm bound according
  # to the chosen `clip_mechanism` and `dp_mechanism`.
  if dp_mechanism in ['central-gaussian', 'distributed-discrete-gaussian']:
    if clip_mechanism == 'sub-sampling':
      l2_norm_bound = max_records_per_user * math.sqrt(
          _tree_depth(num_bins, arity))
    elif clip_mechanism == 'distinct':
      # The following code block converts `max_records_per_user` to the L2
      # norm bound of the hierarchical histogram layer by layer. For the
      # bottom layer with only 0s and at most `max_records_per_user` 1s, the
      # L2 norm bound is `sqrt(max_records_per_user)`. For the second layer
      # from the bottom, the worst case is only 0s and
      # `max_records_per_user / 2` 2s. And so on until the root node. Another
      # natural L2 norm bound on each layer is `max_records_per_user`, so we
      # take the minimum of the two bounds.
      square_l2_norm_bound = 0.
      square_layer_l2_norm_bound = max_records_per_user
      for _ in range(_tree_depth(num_bins, arity)):
        square_l2_norm_bound += min(max_records_per_user**2,
                                    square_layer_l2_norm_bound)
        square_layer_l2_norm_bound *= arity
      l2_norm_bound = math.sqrt(square_l2_norm_bound)

  if not enable_secure_sum and dp_mechanism in DISTRIBUTED_DP_MECHANISMS:
    raise ValueError(f'When dp_mechanism is {DISTRIBUTED_DP_MECHANISMS}, '
                     'enable_secure_sum must be set to True to preserve '
                     'distributed DP.')

  # Build the nested aggregation factory from innermost to outermost.
  # 1. Sum factory. The innermost factory that sums the preprocessed records.
  #    (1) If `enable_secure_sum` is `False`, should be `SumFactory`.
  if not enable_secure_sum:
    nested_factory = sum_factory.SumFactory()
  else:
    # (2) If `enable_secure_sum` is `True` and `dp_mechanism` is 'no-noise'
    #     or 'central-gaussian', the sum factory should be `SecureSumFactory`
    #     with an `upper_bound_threshold` of `max_records_per_user`. When
    #     `dp_mechanism` is 'central-gaussian', use a float
    #     `SecureSumFactory` to be compatible with `GaussianSumQuery`.
    if dp_mechanism in ['no-noise']:
      nested_factory = secure.SecureSumFactory(max_records_per_user)
    elif dp_mechanism in ['central-gaussian']:
      nested_factory = secure.SecureSumFactory(float(max_records_per_user))
    # (3) If `dp_mechanism` is in `DISTRIBUTED_DP_MECHANISMS`, should be
    #     `SecureSumFactory`. To preserve DP and avoid overflow, we have 4
    #     modular clips from nesting two modular clip aggregators:
    #     #1. outer-client: clips to [-2**(bits-1), 2**(bits-1))
    #         Bounds the client values.
    #     #2. inner-client: clips to [0, 2**bits)
    #         Similar to applying a two's complement to the values such that
    #         frequent values (post-rotation) are now near 0 (representing
    #         small positives) and 2**bits (small negatives). 0 also always
    #         maps to 0, and we do not require another explicit value range
    #         shift from [-2**(bits-1), 2**(bits-1)] to [0, 2**bits] to make
    #         sure that values are compatible with SecAgg's mod m = 2**bits.
    #         This can be reverted at #4.
    #     #3. inner-server: clips to [0, 2**bits)
    #         Ensures the aggregated value range does not grow by
    #         `log2(expected_clients_per_round)`.
    #         NOTE: If the underlying SecAgg is implemented using the new
    #         `tff.federated_secure_modular_sum()` operator with the same
    #         modular clipping range, then this would correspond to a no-op.
    #     #4. outer-server: clips to [-2**(bits-1), 2**(bits-1))
    #         Keeps aggregated values centered near 0 out of the logical
    #         SecAgg black box for outer aggregators.
    elif dp_mechanism in ['distributed-discrete-gaussian']:
      # TODO(b/196312838): Please add scaling to the distributed case once we
      # have a stable guideline for setting the scaling factor to improve
      # performance and avoid overflow. The test below makes sure that
      # modular clipping happens with small probability so the accuracy of
      # the result won't be harmed. However, if the number of clients exceeds
      # `expected_clients_per_round`, overflow still might happen. It is the
      # caller's responsibility to carefully choose `bits` according to
      # system details to avoid overflow or performance degradation.
      if bits < math.log2(4 * math.sqrt(expected_clients_per_round) *
                          noise_multiplier * l2_norm_bound +
                          expected_clients_per_round *
                          max_records_per_user) + 1:
        raise ValueError(
            f'The selected bit-width ({bits}) is too small for the '
            f'given parameters (expected_clients_per_round = '
            f'{expected_clients_per_round}, max_records_per_user = '
            f'{max_records_per_user}, noise_multiplier = '
            f'{noise_multiplier}) and will harm the accuracy of the '
            f'result. Please decrease the '
            f'`expected_clients_per_round` / `max_records_per_user` '
            f'/ `noise_multiplier`, or increase `bits`.')
      nested_factory = secure.SecureSumFactory(
          upper_bound_threshold=2**bits - 1, lower_bound_threshold=0)
      nested_factory = modular_clipping_factory.ModularClippingSumFactory(
          clip_range_lower=0,
          clip_range_upper=2**bits,
          inner_agg_factory=nested_factory)
      nested_factory = modular_clipping_factory.ModularClippingSumFactory(
          clip_range_lower=-2**(bits - 1),
          clip_range_upper=2**(bits - 1),
          inner_agg_factory=nested_factory)

  # 2. DP operations. Construct a `DifferentiallyPrivateFactory` according to
  #    the chosen `dp_mechanism`.
  if dp_mechanism == 'central-gaussian':
    query = tfp.TreeRangeSumQuery.build_central_gaussian_query(
        l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
    # If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`,
    # then the record is cast to `tf.float32` before feeding it to the DP
    # factory.
    cast_to_float = True
  elif dp_mechanism == 'distributed-discrete-gaussian':
    query = tfp.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
        l2_norm_bound,
        noise_multiplier * l2_norm_bound /
        math.sqrt(expected_clients_per_round), arity)
    # If the inner `DifferentiallyPrivateFactory` uses
    # `DistributedDiscreteGaussianQuery`, then the record is kept as
    # `tf.int32` before feeding it to the DP factory.
    cast_to_float = False
  elif dp_mechanism == 'no-noise':
    inner_query = tfp.NoPrivacySumQuery()
    query = tfp.TreeRangeSumQuery(arity=arity, inner_query=inner_query)
    # If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
    # the record is kept as `tf.int32` before feeding it to the DP factory.
    cast_to_float = False
  else:
    raise ValueError('Unexpected dp_mechanism.')
  nested_factory = differential_privacy.DifferentiallyPrivateFactory(
      query, nested_factory)

  # 3. Clip as specified by `clip_mechanism`.
  nested_factory = clipping_factory.HistogramClippingSumFactory(
      clip_mechanism=clip_mechanism,
      max_records_per_user=max_records_per_user,
      inner_agg_factory=nested_factory,
      cast_to_float=cast_to_float)

  return nested_factory
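# A minimal usage sketch (not from the library): constructs the default
# no-noise, secure-sum variant for a 4-bin histogram. All parameter values
# here are illustrative.
factory_ = create_hierarchical_histogram_aggregation_factory(
    num_bins=4,
    arity=2,
    clip_mechanism='sub-sampling',
    max_records_per_user=10,
    dp_mechanism='no-noise')
# Like any `tff.aggregators.UnweightedAggregationFactory`, calling
# `factory_.create(value_type)` yields an `AggregationProcess` whose
# `initialize`/`next` are driven by the federated training loop.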
def _build_aggregation_factory(self):
  central_stddev = self._value_noise_mult * self._initial_l2_clip
  local_stddev = central_stddev / math.sqrt(self._num_clients)

  # Ensure dim is at least 1, only for computing DDP parameters.
  self._client_dim = max(1, self._client_dim)

  if self._rotation_type == 'hd':
    # The Hadamard transform requires the dimension to be a power of 2.
    self._padded_dim = 2**math.ceil(math.log2(self._client_dim))
    rotation_factory = rotation.HadamardTransformFactory
  else:
    # The DFT pads at most 1 zero.
    self._padded_dim = math.ceil(self._client_dim / 2.0) * 2
    rotation_factory = rotation.DiscreteFourierTransformFactory

  scale = _heuristic_scale_factor(local_stddev, self._initial_l2_clip,
                                  self._bits, self._num_clients,
                                  self._padded_dim, self._k_stddevs).numpy()

  # Very large scales could lead to overflows and are not as helpful for
  # utility. See the comment above for more details.
  scale = min(scale, MAX_SCALE_FACTOR)

  if scale <= 1:
    warnings.warn(
        f'The selected scale_factor {scale} <= 1. This may lead to '
        f'substantial quantization errors. Consider increasing '
        f'the bit-width (currently {self._bits}) or decreasing the '
        f'expected number of clients per round (currently '
        f'{self._num_clients}).')

  # The procedure for obtaining the inflated L2 bound assumes eager TF
  # execution and can be rewritten with NumPy if needed.
  inflated_l2 = discretization.inflated_l2_norm_bound(
      l2_norm_bound=self._initial_l2_clip,
      gamma=1.0 / scale,
      beta=self._beta,
      dim=self._padded_dim).numpy()

  # Add a small leeway on the norm bounds to gracefully allow numerical
  # errors. Specifically, the norm thresholds are computed directly from the
  # specified parameters in Python and will be checked right before noising;
  # on the other hand, the actual norm of the record (measured at noising
  # time) can possibly be (negligibly) higher due to the float32 arithmetic
  # after the conditional rounding (thus failing the check). While we have
  # mitigated this by sharing the computation for the inflated norm bound
  # from quantization, adding a leeway makes the execution more robust (it
  # does not need to abort should any precision issues happen) while not
  # affecting correctness if privacy accounting is done based on the norm
  # bounds at the DPQuery/DPFactory (which incorporates the leeway).
  scaled_inflated_l2 = (inflated_l2 + 1e-5) * scale
  # Since values are scaled and rounded to integers, we have L1 <= L2^2 on
  # top of the general bound L1 <= sqrt(d) * L2.
  scaled_l1 = math.ceil(
      scaled_inflated_l2 * min(math.sqrt(self._padded_dim),
                               scaled_inflated_l2))

  # Build the nested aggregation factory.
  # 1. Secure aggregation. In particular, we have 4 modular clips from
  #    nesting two modular clip aggregators:
  #    #1. outer-client: clips to [-2^(b-1), 2^(b-1)]
  #        Bounds the client values (with limited effect as the scaling was
  #        chosen such that `num_clients` is taken into account).
  #    #2. inner-client: clips to [0, 2^b]
  #        Similar to applying a two's complement to the values such that
  #        frequent values (post-rotation) are now near 0 (representing small
  #        positives) and 2^b (small negatives). 0 also always maps to 0, and
  #        we do not require another explicit value range shift from
  #        [-2^(b-1), 2^(b-1)] to [0, 2^b] to make sure that values are
  #        compatible with SecAgg's mod m = 2^b. This can be reverted at #4.
  #    #3. inner-server: clips to [0, 2^b]
  #        Ensures the aggregated value range does not grow by log_2(n).
  #        NOTE: If the underlying SecAgg is implemented using the new
  #        `tff.federated_secure_modular_sum()` operator with the same
  #        modular clipping range, then this would correspond to a no-op.
  #    #4. outer-server: clips to [-2^(b-1), 2^(b-1)]
  #        Keeps aggregated values centered near 0 out of the logical SecAgg
  #        black box for outer aggregators.
  # Note that the scaling factor and the bit-width are chosen such that the
  # number of clients to aggregate is taken into account.
  nested_factory = secure.SecureSumFactory(
      upper_bound_threshold=2**self._bits - 1, lower_bound_threshold=0)
  nested_factory = modular_clipping.ModularClippingSumFactory(
      clip_range_lower=0,
      clip_range_upper=2**self._bits,
      inner_agg_factory=nested_factory)
  nested_factory = modular_clipping.ModularClippingSumFactory(
      clip_range_lower=-(2**(self._bits - 1)),
      clip_range_upper=2**(self._bits - 1),
      inner_agg_factory=nested_factory)

  # 2. DP operations. DP params are in the scaled domain (post-quantization).
  if self._mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=scaled_inflated_l2, local_stddev=local_stddev * scale)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=scaled_l1,
        l2_norm_bound=scaled_inflated_l2,
        local_stddev=local_stddev * scale)
  nested_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=dp_query, record_aggregation_factory=nested_factory)

  # 3. Discretization operations. This appropriately quantizes the inputs.
  nested_factory = discretization.DiscretizationFactory(
      inner_agg_factory=nested_factory,
      scale_factor=scale,
      stochastic=True,
      beta=self._beta,
      prior_norm_bound=self._initial_l2_clip)

  # 4. L2 clip, possibly adaptively with a `tff.templates.EstimationProcess`.
  nested_factory = robust.clipping_factory(
      clipping_norm=self._l2_clip,
      inner_agg_factory=nested_factory,
      clipped_count_sum_factory=secure.SecureSumFactory(
          upper_bound_threshold=1, lower_bound_threshold=0))

  # 5. Flattening to improve quantization and reduce modular wrapping.
  nested_factory = rotation_factory(inner_agg_factory=nested_factory)

  # 6. Concat the input structure into a single vector.
  nested_factory = concat.concat_factory(inner_agg_factory=nested_factory)

  return nested_factory
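# A hedged numeric sketch (plain Python, not part of the factory) of the
# nested modular-clipping ("two's complement") trick described in steps
# #1-#4 above, using an illustrative bit-width b = 8.
b = 8
m = 2**b  # SecAgg field size.

def mod_clip(x, lo, hi):
  # Wrap x into [lo, hi) via modular arithmetic, as a modular clip would.
  return (x - lo) % (hi - lo) + lo

# Client side: a small negative value becomes a large residue near m (#2).
client_values = [3, -5, 7]
encoded = [mod_clip(v, 0, m) for v in client_values]  # [3, 251, 7]

# Server side: SecAgg sums mod m (#3), then the outer clip recenters the
# result into [-2^(b-1), 2^(b-1)) so it reads as a signed sum again (#4).
secagg_sum = sum(encoded) % m                    # 261 % 256 == 5
decoded = mod_clip(secagg_sum, -m // 2, m // 2)  # back to a signed value
assert decoded == sum(client_values) == 5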