def test_clip_type_properties_with_clipped_count_agg_factory(
    self, value_type):
  """Checks type signatures when a custom clipped-count factory is supplied.

  `aggregator_test_utils.SumPlusOneFactory` is passed as the clipped-count
  aggregator. The test asserts that its state shows up as `tf.int32` under
  `clipped_count_agg` in the server state. It also checks that `next` reports
  `clipping`, `clipping_norm` and `clipped_count` measurements.
  """
  inner_factory = sum_factory.SumFactory()
  count_factory = aggregator_test_utils.SumPlusOneFactory()
  clip_factory = robust.clipping_factory(
      clipping_norm=1.0,
      inner_agg_factory=inner_factory,
      clipped_count_sum_factory=count_factory)

  tff_value_type = computation_types.to_type(value_type)
  process = clip_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Server state: empty states for the fixed norm and the plain sum, plus the
  # integer state carried by the custom clipped-count aggregator.
  state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=tf.int32))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          clipping=(),
          clipping_norm=robust.NORM_TF_TYPE,
          clipped_count=robust.COUNT_TF_TYPE))
  next_parameter = collections.OrderedDict(
      state=state_type, value=computation_types.at_clients(tff_value_type))
  next_result = measured_process.MeasuredProcessOutput(
      state=state_type,
      result=computation_types.at_server(tff_value_type),
      measurements=measurements_type)
  next_type = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def _default_clipping(
    inner_factory: factory.AggregationFactory) -> factory.AggregationFactory:
  """Wraps `inner_factory` in the default adaptive clipping factory.

  The clipping norm comes from a noiseless quantile-estimation process. It
  starts at 1.0, targets the 0.8 quantile and uses a learning rate of 0.2, so
  it adapts relatively quickly to a moderately high norm.

  Args:
    inner_factory: The aggregation factory applied to the clipped values.

  Returns:
    The factory produced by `robust.clipping_factory`.
  """
  adaptive_norm = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
      initial_estimate=1.0, target_quantile=0.8, learning_rate=0.2)
  return robust.clipping_factory(
      clipping_norm=adaptive_norm, inner_agg_factory=inner_factory)
def _default_clipping(
    inner_factory: factory.AggregationFactory,
    secure_estimation: bool = False) -> factory.AggregationFactory:
  """Wraps `inner_factory` in the default adaptive clipping factory.

  The clipping norm comes from a noiseless quantile-estimation process. It
  starts at 1.0, targets the 0.8 quantile and uses a learning rate of 0.2, so
  it adapts relatively quickly to a moderately high norm.

  Args:
    inner_factory: The aggregation factory applied to the clipped values.
    secure_estimation: Forwarded to the quantile-estimation process. When
      True, the count of clipped values is also aggregated with a
      `secure.SecureSumFactory` instead of the default count aggregator.

  Returns:
    The factory produced by `robust.clipping_factory`.
  """
  adaptive_norm = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
      initial_estimate=1.0,
      target_quantile=0.8,
      learning_rate=0.2,
      secure_estimation=secure_estimation)
  if not secure_estimation:
    return robust.clipping_factory(adaptive_norm, inner_factory)

  # Secure sum with bounds [0, 1] for the clipped-count aggregation; this
  # presumably reflects a per-client 0/1 "was clipped" indicator (confirm).
  count_factory = secure.SecureSumFactory(
      upper_bound_threshold=1, lower_bound_threshold=0)
  return robust.clipping_factory(
      adaptive_norm, inner_factory, clipped_count_sum_factory=count_factory)
def _clipped_sum(clip=2.0):
  """Returns a `SumFactory` wrapped by `robust.clipping_factory` at norm `clip`."""
  return robust.clipping_factory(
      clipping_norm=clip, inner_agg_factory=sum_factory.SumFactory())
def _clipped_mean(clip=2.0):
  """Returns a `MeanFactory` wrapped by `robust.clipping_factory` at norm `clip`."""
  return robust.clipping_factory(
      clipping_norm=clip, inner_agg_factory=mean.MeanFactory())
def test_type_properties(self, value_type, mechanism):
  """Checks the DDP factory's type signatures and that it aggregates securely.

  An expected process is assembled directly from the component factories:
  secure modular sum, then DP query, discretization, L2 clipping, Hadamard
  rotation and concat. Its `initialize` and `next` type signatures are
  compared against those of the factory under test. The test then asserts
  that `next` contains no non-secure aggregation.
  """
  ddp_factory = _make_test_factory(mechanism=mechanism)
  self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = ddp_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The state is a nested object with component factory states. Construct
  # test factories directly and compare the signatures.
  modsum_f = secure.SecureModularSumFactory(2**15, True)
  if mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=10.0, local_stddev=10.0)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=10.0, l2_norm_bound=10.0, local_stddev=10.0)
  dp_f = differential_privacy.DifferentiallyPrivateFactory(dp_query, modsum_f)
  discrete_f = discretization.DiscretizationFactory(dp_f)
  l2clip_f = robust.clipping_factory(
      clipping_norm=10.0, inner_agg_factory=discrete_f)
  rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
  expected_process = concat.concat_factory(rot_f).create(value_type)

  # Check init_fn/state.
  expected_init_type = expected_process.initialize.type_signature
  expected_state_type = expected_init_type.result
  actual_init_type = process.initialize.type_signature
  self.assertTrue(
      actual_init_type.is_equivalent_to(expected_init_type),
      msg=(f'initialize type mismatch:\n{actual_init_type}\n'
           f'expected:\n{expected_init_type}'))

  # Check next_fn/measurements.
  tensor2type = type_conversions.type_from_tensors
  discrete_state = discrete_f.create(
      computation_types.to_type(tf.float32)).initialize()
  dp_query_state = dp_query.initial_global_state()
  dp_query_metrics_type = tensor2type(dp_query.derive_metrics(dp_query_state))
  expected_measurements_type = collections.OrderedDict(
      l2_clip=robust.NORM_TF_TYPE,
      scale_factor=tensor2type(discrete_state['scale_factor']),
      scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
      scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
      actual_num_clients=tf.int32,
      padded_dim=tf.int32,
      dp_query_metrics=dp_query_metrics_type)
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  actual_next_type = process.next.type_signature
  self.assertTrue(
      actual_next_type.is_equivalent_to(expected_next_type),
      msg=(f'next type mismatch:\n{actual_next_type}\n'
           f'expected:\n{expected_next_type}'))

  # Catch `Exception` rather than using a bare `except:`, so that
  # KeyboardInterrupt / SystemExit still propagate. Include the underlying
  # error in the failure message.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except Exception as e:  # pylint: disable=broad-except
    self.fail('Factory returned an AggregationProcess containing '
              f'non-secure aggregation: {e}')
def _build_aggregation_factory(self):
  """Builds the nested distributed-DP aggregation factory.

  The method first derives the noise stddevs, padded dimension, quantization
  scale and norm bounds from the instance configuration. It then composes the
  factories, from outermost to innermost:
  concat -> rotation -> L2 clipping -> discretization -> DP query ->
  modular clipping (x2) -> secure sum.

  Side effects: clamps `self._client_dim` to at least 1 and sets
  `self._padded_dim`. It may emit a `warnings.warn` if the selected scale
  factor is <= 1.

  Returns:
    The outermost factory, as returned by `concat.concat_factory`.
  """
  central_stddev = self._value_noise_mult * self._initial_l2_clip
  local_stddev = central_stddev / math.sqrt(self._num_clients)

  # Ensure dim is at least 1 only for computing DDP parameters.
  # NOTE(review): the clamp is written back to `self._client_dim`, so it
  # persists beyond this computation. Confirm that is intended.
  self._client_dim = max(1, self._client_dim)
  if self._rotation_type == 'hd':
    # Hadamard transform requires dimension to be powers of 2.
    self._padded_dim = 2**math.ceil(math.log2(self._client_dim))
    rotation_factory = rotation.HadamardTransformFactory
  else:
    # DFT pads at most 1 zero.
    self._padded_dim = math.ceil(self._client_dim / 2.0) * 2
    rotation_factory = rotation.DiscreteFourierTransformFactory

  scale = _heuristic_scale_factor(local_stddev, self._initial_l2_clip,
                                  self._bits, self._num_clients,
                                  self._padded_dim, self._k_stddevs).numpy()

  # Very large scales could lead to overflows and are not as helpful for
  # utility. See comment above for more details.
  scale = min(scale, MAX_SCALE_FACTOR)
  if scale <= 1:
    # Each fragment ends with a space. Implicit string concatenation adds no
    # separator, so the words would otherwise run together.
    warnings.warn(
        f'The selected scale_factor {scale} <= 1. This may lead to '
        f'substantial quantization errors. Consider increasing '
        f'the bit-width (currently {self._bits}) or decreasing the '
        f'expected number of clients per round (currently '
        f'{self._num_clients}).')

  # The procedure for obtaining inflated L2 bound assumes eager TF execution
  # and can be rewritten with NumPy if needed.
  inflated_l2 = discretization.inflated_l2_norm_bound(
      l2_norm_bound=self._initial_l2_clip,
      gamma=1.0 / scale,
      beta=self._beta,
      dim=self._padded_dim).numpy()

  # Add a small leeway on norm bounds to gracefully allow numerical errors.
  # The norm thresholds are computed directly from the specified parameters
  # in Python and are checked right before noising. The actual norm of the
  # record, measured at noising time, can be (negligibly) higher due to the
  # float32 arithmetic after the conditional rounding, which would fail the
  # check. Sharing the computation for the inflated norm bound from
  # quantization already mitigates this. The leeway makes execution more
  # robust (it does not need to abort if a precision issue occurs). It does
  # not affect correctness if privacy accounting uses the norm bounds at the
  # DPQuery/DPFactory, which incorporate the leeway.
  scaled_inflated_l2 = (inflated_l2 + 1e-5) * scale
  # Values are scaled and rounded to integers, so L1 <= L2^2. This holds on
  # top of the general bound L1 <= sqrt(d) * L2.
  scaled_l1 = math.ceil(
      scaled_inflated_l2 *
      min(math.sqrt(self._padded_dim), scaled_inflated_l2))

  # Build the nested aggregation factory.
  # 1. Secure Aggregation. Nesting two modular clip aggregators gives 4
  #    modular clips:
  #   #1. outer-client: clips to [-2^(b-1), 2^(b-1)]
  #       Bounds the client values. The effect is limited because the scaling
  #       was chosen with `num_clients` taken into account.
  #   #2. inner-client: clips to [0, 2^b]
  #       Similar to applying a two's complement to the values, so frequent
  #       values (post-rotation) are now near 0 (small positives) and near
  #       2^b (small negatives). 0 always maps to 0. No explicit value-range
  #       shift from [-2^(b-1), 2^(b-1)] to [0, 2^b] is needed to make values
  #       compatible with SecAgg's mod m = 2^b. This can be reverted at #4.
  #   #3. inner-server: clips to [0, 2^b]
  #       Ensures the aggregated value range does not grow by log_2(n).
  #       NOTE: If the underlying SecAgg is implemented with the new
  #       `tff.federated_secure_modular_sum()` operator and the same modular
  #       clipping range, this would be a no-op.
  #   #4. outer-server: clips to [-2^(b-1), 2^(b-1)]
  #       Keeps aggregated values centered near 0 outside the logical SecAgg
  #       black box, for the outer aggregators.
  # The scaling factor and the bit-width are chosen with the number of
  # clients to aggregate taken into account.
  nested_factory = secure.SecureSumFactory(
      upper_bound_threshold=2**self._bits - 1, lower_bound_threshold=0)
  nested_factory = modular_clipping.ModularClippingSumFactory(
      clip_range_lower=0,
      clip_range_upper=2**self._bits,
      inner_agg_factory=nested_factory)
  nested_factory = modular_clipping.ModularClippingSumFactory(
      clip_range_lower=-(2**(self._bits - 1)),
      clip_range_upper=2**(self._bits - 1),
      inner_agg_factory=nested_factory)

  # 2. DP operations. DP params are in the scaled domain (post-quantization).
  if self._mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=scaled_inflated_l2, local_stddev=local_stddev * scale)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=scaled_l1,
        l2_norm_bound=scaled_inflated_l2,
        local_stddev=local_stddev * scale)

  nested_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=dp_query, record_aggregation_factory=nested_factory)

  # 3. Discretization operations. This appropriately quantizes the inputs.
  nested_factory = discretization.DiscretizationFactory(
      inner_agg_factory=nested_factory,
      scale_factor=scale,
      stochastic=True,
      beta=self._beta,
      prior_norm_bound=self._initial_l2_clip)

  # 4. L2 clip, possibly adaptively with a `tff.templates.EstimationProcess`.
  nested_factory = robust.clipping_factory(
      clipping_norm=self._l2_clip,
      inner_agg_factory=nested_factory,
      clipped_count_sum_factory=secure.SecureSumFactory(
          upper_bound_threshold=1, lower_bound_threshold=0))

  # 5. Flattening to improve quantization and reduce modular wrapping.
  nested_factory = rotation_factory(inner_agg_factory=nested_factory)

  # 6. Concat the input structure into a single vector.
  nested_factory = concat.concat_factory(inner_agg_factory=nested_factory)
  return nested_factory