def test_type_signature_with_structure_of_tensors_and_bitwidths(self):
  """A pair of int16 tensors with per-element bitwidths sums to a SERVER pair."""
  # Only the shape and dtype matter for the type signature. `np.zeros` avoids
  # the low-level `np.ndarray` constructor, which leaves the buffer
  # uninitialized and would embed arbitrary memory contents as the constant.
  np_array = np.zeros(shape=(5, 37), dtype=np.int16)
  value = intrinsics.federated_value((np_array, np_array), placements.CLIENTS)
  # One bitwidth per element of the (tensor, tensor) structure.
  bitwidth = (2, 2)
  result = intrinsics.federated_secure_sum_bitwidth(value, bitwidth)
  self.assert_value(result, '<int16[5,37],int16[5,37]>@SERVER')
def _sum_securely(self, value, upper_bound, lower_bound):
  """Securely sums `value` placed at CLIENTS.

  In `_Config.INT` mode the client values are passed through `_client_shift`
  together with the broadcast bounds, summed with
  `federated_secure_sum_bitwidth` using `self._secagg_bitwidth`, and then
  passed through `_server_shift` along with `lower_bound` and the number of
  participating clients. In `_Config.FLOAT` mode the work is delegated to
  `primitives.secure_quantized_sum`.

  Raises:
    ValueError: If `self._config_mode` is neither `_Config.INT` nor
      `_Config.FLOAT`.
  """
  if self._config_mode == _Config.INT:
    # Broadcast upper first, then lower, matching the argument tuple order.
    clients_upper = intrinsics.federated_broadcast(upper_bound)
    clients_lower = intrinsics.federated_broadcast(lower_bound)
    shifted = intrinsics.federated_map(
        _client_shift, (value, clients_upper, clients_lower))
    summed = intrinsics.federated_secure_sum_bitwidth(
        shifted, self._secagg_bitwidth)
    # Securely summing a per-client constant yields the client count.
    client_count = intrinsics.federated_secure_sum_bitwidth(
        _client_one(), bitwidth=1)
    return intrinsics.federated_map(
        _server_shift, (summed, lower_bound, client_count))

  if self._config_mode == _Config.FLOAT:
    return primitives.secure_quantized_sum(value, lower_bound, upper_bound)

  raise ValueError(
      f'Unexpected internal config type: {self._config_mode}')
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Client updates are aggregated only with `federated_secure_sum_bitwidth`
  (bitwidth 8); there is deliberately no `federated_aggregate` stage.
  """
  prepared = intrinsics.federated_map(prepare, server_state)
  broadcast_value = intrinsics.federated_broadcast(prepared)
  work_arg = intrinsics.federated_zip([client_data, broadcast_value])
  work_output = intrinsics.federated_map(work, work_arg)
  # No call to `federated_aggregate`.
  secure_sum = intrinsics.federated_secure_sum_bitwidth(work_output, 8)
  update_arg = intrinsics.federated_zip([server_state, secure_sum])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output
def _compute_measurements(self, upper_bound, lower_bound, value_max,
                          value_min):
  """Creates measurements to be reported; clip counts are summed securely.

  The upper count is the secure sum over clients of
  `cast(upper_bound < value_max)`, and the lower count is the secure sum of
  `cast(lower_bound > value_min)`. Both use a bitwidth of 1. The two counts
  are zipped together with the bounds themselves.
  """
  exceeds_upper = computations.tf_computation(
      lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE))
  upper_flags = intrinsics.federated_map(
      exceeds_upper, (intrinsics.federated_broadcast(upper_bound), value_max))
  upper_count = intrinsics.federated_secure_sum_bitwidth(
      upper_flags, bitwidth=1)

  below_lower = computations.tf_computation(
      lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE))
  lower_flags = intrinsics.federated_map(
      below_lower, (intrinsics.federated_broadcast(lower_bound), value_min))
  lower_count = intrinsics.federated_secure_sum_bitwidth(
      lower_flags, bitwidth=1)

  return intrinsics.federated_zip(
      collections.OrderedDict(
          secure_upper_clipped_count=upper_count,
          secure_lower_clipped_count=lower_count,
          secure_upper_threshold=upper_bound,
          secure_lower_threshold=lower_bound))
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Skips the prepare/broadcast stage: `work` runs directly on `client_data`.
  Element 0 of its output goes through `federated_sum` and element 1 through
  `federated_secure_sum_bitwidth` with bitwidth 8.
  """
  # No call to `federated_map` with prepare.
  # No call to `federated_broadcast`.
  work_output = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(work_output[0])
  secure_sum = intrinsics.federated_secure_sum_bitwidth(work_output[1], 8)
  update_arg = intrinsics.federated_zip(
      [server_state, [plain_sum, secure_sum]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Has no `update` stage. The zipped `[plain sum, secure sum]` pair is
  returned directly as the new server state, and the server output is an
  empty SERVER-placed value.
  """
  prepared = intrinsics.federated_map(prepare, server_state)
  broadcast_value = intrinsics.federated_broadcast(prepared)
  work_arg = intrinsics.federated_zip([client_data, broadcast_value])
  work_output = intrinsics.federated_map(work, work_arg)
  plain_sum = intrinsics.federated_sum(work_output[0])
  secure_sum = intrinsics.federated_secure_sum_bitwidth(work_output[1], 8)
  new_server_state = intrinsics.federated_zip([plain_sum, secure_sum])
  # No call to `federated_map` with an `update` function.
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Ignores the incoming server state, runs `work` directly on `client_data`,
  and returns an empty SERVER-placed value as the new server state. `update`
  is applied only to the zipped aggregates.
  """
  del server_state  # Unused
  # No call to `federated_map` with prepare.
  # No call to `federated_broadcast`.
  work_output = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(work_output[0])
  secure_sum = intrinsics.federated_secure_sum_bitwidth(work_output[1], 8)
  update_arg = intrinsics.federated_zip([plain_sum, secure_sum])
  # Empty server state.
  new_server_state = intrinsics.federated_value([], placements.SERVER)
  server_output = intrinsics.federated_map(update, update_arg)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Routes the prepared state through `broadcast_and_return_arg_and_result`.
  Only the second element of that result is broadcast to the clients; the
  first is discarded.
  """
  prepared = intrinsics.federated_map(prepare, server_state)
  unused_client_input, to_broadcast = broadcast_and_return_arg_and_result(
      prepared)
  broadcast_value = intrinsics.federated_broadcast(to_broadcast)
  work_arg = intrinsics.federated_zip([client_data, broadcast_value])
  work_output = intrinsics.federated_map(work, work_arg)
  plain_sum = intrinsics.federated_sum(work_output[0])
  secure_sum = intrinsics.federated_secure_sum_bitwidth(work_output[1], 8)
  update_arg = intrinsics.federated_zip(
      [server_state, [plain_sum, secure_sum]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output
def next_computation(arg):
  """The logic of a single MapReduce processing round.

  `arg[0]` is the server state and `arg[1]` the client data. The round runs
  prepare -> broadcast -> work. Element 0 of the work output goes through
  `federated_aggregate` with the `mrf` callbacks, and element 1 through
  `federated_secure_sum_bitwidth` with `mrf.bitwidth()`. `mrf.update` then
  receives the server state plus both aggregates and returns the two
  elements of its result.
  """
  server_state = arg[0]
  client_data = arg[1]
  prepared = intrinsics.federated_map(mrf.prepare, server_state)
  broadcast_value = intrinsics.federated_broadcast(prepared)
  work_arg = intrinsics.federated_zip([client_data, broadcast_value])
  work_output = intrinsics.federated_map(mrf.work, work_arg)
  aggregate_input = work_output[0]
  secure_input = work_output[1]
  aggregate_result = intrinsics.federated_aggregate(
      aggregate_input, mrf.zero(), mrf.accumulate, mrf.merge, mrf.report)
  secure_result = intrinsics.federated_secure_sum_bitwidth(
      secure_input, mrf.bitwidth())
  aggregates = intrinsics.federated_zip([aggregate_result, secure_result])
  update_arg = intrinsics.federated_zip([server_state, aggregates])
  update_output = intrinsics.federated_map(mrf.update, update_arg)
  return update_output[0], update_output[1]
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Exercises all four client-to-server aggregations on elements 0..3 of the
  `work` output:
  - `federated_sum`
  - `federated_secure_sum_bitwidth` (bitwidth=8)
  - `federated_secure_sum` (max_input=1)
  - `federated_secure_modular_sum` (modulus=8)

  The zipped results become the new server state directly.
  """
  del server_state  # Unused
  # No call to `federated_map` with prepare.
  # No call to `federated_broadcast`.
  work_output = intrinsics.federated_map(work, client_data)
  plain_sum = intrinsics.federated_sum(work_output[0])
  bitwidth_sum = intrinsics.federated_secure_sum_bitwidth(
      work_output[1], bitwidth=8)
  max_input_sum = intrinsics.federated_secure_sum(
      work_output[2], max_input=1)
  modular_sum = intrinsics.federated_secure_modular_sum(
      work_output[3], modulus=8)
  new_server_state = intrinsics.federated_zip(
      [plain_sum, bitwidth_sum, max_input_sum, modular_sum])
  # No call to `federated_map` with an `update` function.
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def _derive_measurements(self, agg_state, agg_measurements):
  """Builds the zipped measurements from aggregator state and sub-metrics.

  Combines the clipping norm, the discrete scale factor, DP query state
  fields, a securely summed client count and the padded dimension into a
  single zipped OrderedDict.
  """
  _, discrete_state, dp_state = self._unpack_state(agg_state)
  l2_clip_metrics, _, dp_metrics = self._unpack_measurements(
      agg_measurements)
  dp_query_state, _ = dp_state

  # Every client contributes 1, so the bitwidth-1 secure sum counts clients.
  ones_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
  actual_num_clients = intrinsics.federated_secure_sum_bitwidth(
      ones_at_clients, bitwidth=1)
  padded_dim = intrinsics.federated_value(
      int(self._padded_dim), placements.SERVER)

  measurements = collections.OrderedDict([
      ('l2_clip', l2_clip_metrics['clipping_norm']),
      ('scale_factor', discrete_state['scale_factor']),
      ('scaled_inflated_l2', dp_query_state.l2_norm_bound),
      ('scaled_local_stddev', dp_query_state.local_stddev),
      ('actual_num_clients', actual_num_clients),
      ('padded_dim', padded_dim),
      ('dp_query_metrics', dp_metrics['dp_query_metrics']),
  ])
  return intrinsics.federated_zip(measurements)
def next_computation(arg):
  """The logic of a single MapReduce processing round.

  `arg` unpacks into (server state, client data). After prepare -> broadcast
  -> work, the four work outputs are aggregated by, respectively:
  - `federated_aggregate`
  - `federated_secure_sum_bitwidth`
  - `federated_secure_sum`
  - `federated_secure_modular_sum`

  The aggregation parameters all come from `mrf`. `mrf.update` then receives
  the server state and the four results.
  """
  server_state, client_data = arg
  prepared = intrinsics.federated_map(mrf.prepare, server_state)
  broadcast_value = intrinsics.federated_broadcast(prepared)
  work_input = intrinsics.federated_zip([client_data, broadcast_value])
  work_output = intrinsics.federated_map(mrf.work, work_input)
  (to_aggregate, to_bitwidth_sum, to_max_input_sum,
   to_modular_sum) = work_output

  aggregated = intrinsics.federated_aggregate(
      to_aggregate, mrf.zero(), mrf.accumulate, mrf.merge, mrf.report)
  bitwidth_summed = intrinsics.federated_secure_sum_bitwidth(
      to_bitwidth_sum, mrf.secure_sum_bitwidth())
  max_input_summed = intrinsics.federated_secure_sum(
      to_max_input_sum, mrf.secure_sum_max_input())
  modular_summed = intrinsics.federated_secure_modular_sum(
      to_modular_sum, mrf.secure_modular_sum_modulus())

  all_results = (aggregated, bitwidth_summed, max_input_summed,
                 modular_summed)
  update_input = intrinsics.federated_zip((server_state, all_results))
  new_state, output = intrinsics.federated_map(mrf.update, update_input)
  return new_state, output
def secure_quantized_sum(client_value, lower_bound, upper_bound):
  """Quantizes and sums values securely.

  Provided `client_value` can be either a Tensor or a nested structure of
  Tensors. If it is a nested structure, `lower_bound` and `upper_bound` must be
  either both scalars, or both have the same structure as `client_value`, with
  each element being a scalar, representing the bounds to be used for each
  corresponding Tensor in `client_value`.

  This method converts each Tensor in provided `client_value` to appropriate
  format and uses the `tff.federated_secure_sum_bitwidth` operator to realize
  the sum.

  The dtype of Tensors in provided `client_value` can be one of `[tf.int32,
  tf.int64, tf.float32, tf.float64]`.

  If the dtype of `client_value` is `tf.int32` or `tf.int64`, the summation is
  possibly exact, depending on `lower_bound` and `upper_bound`: In the case that
  `upper_bound - lower_bound < 2**32`, the summation will be exact. If it is
  not, `client_value` will be quantized to precision of 32 bits, so the worst
  case error introduced for the value of each client will be approximately
  `(upper_bound - lower_bound) / 2**32`. Deterministic rounding to nearest value
  is used in such cases.

  If the dtype of `client_value` is `tf.float32` or `tf.float64`, the summation
  is generally *not* accurate up to full floating point precision. Instead, the
  values are first clipped to the `[lower_bound, upper_bound]` range. These
  values are then uniformly quantized to 32 bit resolution, using deterministic
  rounding to round the values to the quantization points. Rounding happens
  roughly as follows (implementation is a bit more complex to mitigate
  numerical stability issues):

  ```
  values = tf.round(
      (client_value - lower_bound) * ((2**32 - 1) / (upper_bound - lower_bound)))
  ```

  After summation, the inverse operation is performed, so the return value is
  of the same dtype as the input `client_value`.

  In terms of accuracy, it is safe to assume accuracy within 7-8 significant
  digits for `tf.float32` inputs, and 8-9 significant digits for `tf.float64`
  inputs, where the significant digits refer to precision relative to the range
  of the provided bounds. Thus, these bounds should not be set extremely wide.
  Accuracy losses arise due to (1) quantization within the given clipping
  range, (2) float precision of final outputs (e.g. `tf.float32` has 23 bits in
  its mantissa), and (3) precision losses that arise in doing math on
  `tf.float32` and `tf.float64` inputs.

  As a concrete example, if the range is `+/- 1000`, errors up to `1e-4` per
  element should be expected for `tf.float32` and up to `1e-5` for
  `tf.float64`.

  Args:
    client_value: A `tff.Value` placed at `tff.CLIENTS`.
    lower_bound: The smallest possible value for `client_value` (inclusive).
      Values smaller than this bound will be clipped. Must be either a scalar or
      a nested structure of scalars, matching the structure of `client_value`.
      Must be either a Python constant or a `tff.Value` placed at `tff.SERVER`,
      with dtype matching that of `client_value`.
    upper_bound: The largest possible value for `client_value` (inclusive).
      Values greater than this bound will be clipped. Must be either a scalar or
      a nested structure of scalars, matching the structure of `client_value`.
      Must be either a Python constant or a `tff.Value` placed at `tff.SERVER`,
      with dtype matching that of `client_value`.

  Returns:
    Summed `client_value` placed at `tff.SERVER`, of the same dtype as
    `client_value`.

  Raises:
    TypeError (or its subclasses): If input arguments do not satisfy the type
      constraints specified above.
  """

  # Possibly converts Python constants to federated values.
  client_value, lower_bound, upper_bound = _normalize_secure_quantized_sum_args(
      client_value, lower_bound, upper_bound)

  # This object is used during decoration of the `client_shift` method, and the
  # value stored in this mutable container is used during decoration of the
  # `server_shift` method. The reason for this is that we cannot currently get
  # the needed information out of `client_value.type_signature.member` as we
  # need both the `TensorType` information as well as the Python container
  # attached to them.
  # NOTE(review): this relies on `client_shift` being traced (via the first
  # `federated_map` below) before `server_shift`; do not reorder those calls.
  temp_box = []

  # These tf_computations assume the inputs were already validated. In
  # particular, that lower_bnd and upper_bnd have the same structure, and if not
  # scalar, the structure matches the structure of value.
  @tensorflow_computation.tf_computation()
  def client_shift(value, lower_bnd, upper_bnd):
    # Guards against a second trace, which would append a stale dtype
    # structure to `temp_box`.
    assert not temp_box
    # Records the original dtypes (in `value`'s nest structure). `server_shift`
    # passes them to the server-side shift, presumably to restore the input
    # dtype.
    temp_box.append(tf.nest.map_structure(lambda v: v.dtype, value))
    fn = _client_tensor_shift_for_secure_sum
    if tf.is_tensor(lower_bnd):
      # Scalar bounds: the same bounds apply to every tensor in `value`.
      return tf.nest.map_structure(lambda v: fn(v, lower_bnd, upper_bnd), value)
    else:
      # Structured bounds: zip them element-wise with `value`.
      return tf.nest.map_structure(fn, value, lower_bnd, upper_bnd)

  @tensorflow_computation.tf_computation()
  def server_shift(value, lower_bnd, upper_bnd, summands):
    fn = _server_tensor_shift_for_secure_sum
    if tf.is_tensor(lower_bnd):
      # Scalar bounds: the same bounds apply to every summed tensor.
      return tf.nest.map_structure(
          lambda v, dtype: fn(summands, v, lower_bnd, upper_bnd, dtype), value,
          temp_box[0])
    else:
      # Structured bounds: zip bounds and recorded dtypes element-wise.
      return tf.nest.map_structure(lambda *args: fn(summands, *args), value,
                                   lower_bnd, upper_bnd, temp_box[0])

  # Each client contributes 1; securely summing this gives the number of
  # summands, which `server_shift` receives as `summands`.
  client_one = intrinsics.federated_value(1, placements.CLIENTS)

  # Orchestration.
  client_lower_bound = intrinsics.federated_broadcast(lower_bound)
  client_upper_bound = intrinsics.federated_broadcast(upper_bound)

  value = intrinsics.federated_map(
      client_shift, (client_value, client_lower_bound, client_upper_bound))
  num_summands = intrinsics.federated_secure_sum_bitwidth(
      client_one, bitwidth=1)

  # Build a 32-bit secure-aggregation width for the shifted value: a single
  # int for one tensor, or one entry per leaf for a structure.
  secagg_value_type = value.type_signature.member
  assert secagg_value_type.is_tensor() or secagg_value_type.is_struct()
  if secagg_value_type.is_tensor():
    bitwidths = 32
  else:
    bitwidths = structure.map_structure(lambda t: 32, secagg_value_type)

  value = intrinsics.federated_secure_sum_bitwidth(value, bitwidth=bitwidths)
  value = intrinsics.federated_map(
      server_shift, (value, lower_bound, upper_bound, num_summands))
  return value
def secure_aggregation():
  """Securely sums the constant 1 from every client using a bitwidth of 1."""
  return intrinsics.federated_secure_sum_bitwidth(
      intrinsics.federated_value(1, placements.CLIENTS), 1)
def test_type_signature_with_int(self):
  """A scalar int at CLIENTS securely sums to an `int32` at SERVER."""
  client_value = intrinsics.federated_value(1, placements.CLIENTS)
  summed = intrinsics.federated_secure_sum_bitwidth(client_value, 8)
  self.assert_value(summed, 'int32@SERVER')
def test_returns_map_reduce_form_with_secure_sum_bitwidth(self):
  """The bitwidth used in the computation is exposed by the MapReduceForm."""

  def client_to_server(data):
    return intrinsics.federated_secure_sum_bitwidth(data, 7)

  mrf = self.get_map_reduce_form_for_client_to_server_fn(client_to_server)
  self.assertEqual(mrf.secure_sum_bitwidth(), (7,))
def test_type_signature_with_structure_of_ints_scalar_bitwidth(self):
  """With a scalar bitwidth, a nested int structure keeps its shape at SERVER."""
  client_value = intrinsics.federated_value([1, [1, 1]], placements.CLIENTS)
  summed = intrinsics.federated_secure_sum_bitwidth(client_value, 8)
  self.assert_value(summed, '<int32,<int32,int32>>@SERVER')
def test_type_signature_with_one_tensor_and_bitwidth(self):
  """A single int16 tensor at CLIENTS securely sums to the same type at SERVER."""
  # Only the shape and dtype matter here. `np.zeros` replaces the low-level
  # `np.ndarray` constructor, whose buffer is uninitialized and would make the
  # embedded constant depend on arbitrary memory contents.
  value = intrinsics.federated_value(
      np.zeros(shape=(5, 37), dtype=np.int16), placements.CLIENTS)
  bitwidth = 2
  result = intrinsics.federated_secure_sum_bitwidth(value, bitwidth)
  self.assert_value(result, 'int16[5,37]@SERVER')
def test_raises_type_error_with_bitwidth_int_at_server(self):
  """A SERVER-placed federated bitwidth is rejected with `TypeError`."""
  client_value = intrinsics.federated_value(1, placements.CLIENTS)
  server_bitwidth = intrinsics.federated_value(1, placements.SERVER)
  # Only the secure-sum call is expected to raise, so it alone is wrapped.
  with self.assertRaises(TypeError):
    intrinsics.federated_secure_sum_bitwidth(client_value, server_bitwidth)
def test_raises_type_error_with_different_structures(self):
  """A bitwidth structure that does not match the value is rejected."""
  client_value = intrinsics.federated_value([1, [1, 1]], placements.CLIENTS)
  # A flat three-element list vs. the nested `[x, [y, z]]` value structure.
  mismatched_bitwidth = [8, 4, 2]
  with self.assertRaises(TypeError):
    intrinsics.federated_secure_sum_bitwidth(client_value,
                                             mismatched_bitwidth)
def sum_with_bitwidth(arg):
  """Securely sums `arg` using the free variable `bitwidth` from outer scope."""
  return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth=bitwidth)