def next_fn(state, value):
  """One round: encode client values, aggregate, then decode on the server.

  Args:
    state: Server-placed encoder state.
    value: Client-placed values to aggregate.

  Returns:
    A `MeasuredProcessOutput` with updated state, decoded result, and empty
    measurements.
  """
  # Derive the three parameter structures for this round from server state.
  encode_params, decode_before_sum_params, decode_after_sum_params = (
      intrinsics.federated_map(get_params_fn, state))
  # Clients need the encode and pre-sum decode params; the post-sum decode
  # params stay on the server.
  encode_params = intrinsics.federated_broadcast(encode_params)
  decode_before_sum_params = intrinsics.federated_broadcast(
      decode_before_sum_params)
  encoded_values = intrinsics.federated_map(
      encode_fn, [value, encode_params, decode_before_sum_params])
  aggregated_values = intrinsics.federated_aggregate(
      encoded_values, zero_fn(), accumulate_fn, merge_fn, report_fn)
  # Server-side decode of the aggregated (summed) encoded values.
  decoded_values = intrinsics.federated_map(
      decode_after_sum_fn,
      [aggregated_values.values, decode_after_sum_params])
  updated_state = intrinsics.federated_map(
      update_state_fn, [state, aggregated_values.state_update_tensors])
  empty_metrics = intrinsics.federated_value((), placements.SERVER)
  return measured_process.MeasuredProcessOutput(
      state=updated_state, result=decoded_values,
      measurements=empty_metrics)
def next_fn(state, value):
  """One round: discretize client values, aggregate, then undiscretize.

  Args:
    state: Server-placed OrderedDict with 'scale_factor', 'prior_norm_bound'
      and 'inner_agg_process' state.
    value: Client-placed values.

  Returns:
    A `MeasuredProcessOutput` with zipped state, the undiscretized aggregate,
    and the inner process's measurements under key 'discretize'.
  """
  server_scale_factor = state['scale_factor']
  client_scale_factor = intrinsics.federated_broadcast(server_scale_factor)
  server_prior_norm_bound = state['prior_norm_bound']
  prior_norm_bound = intrinsics.federated_broadcast(server_prior_norm_bound)
  # Clients map to the discretized (integer) domain before aggregation.
  discretized_value = intrinsics.federated_map(
      discretize_fn, (value, client_scale_factor, prior_norm_bound))
  inner_state = state['inner_agg_process']
  inner_agg_output = inner_agg_process.next(inner_state, discretized_value)
  # Server maps the aggregate back to the original (float) domain.
  undiscretized_agg_value = intrinsics.federated_map(
      undiscretize_fn, (inner_agg_output.result, server_scale_factor))
  # Scale factor and norm bound are static across rounds; only the inner
  # process state changes.
  new_state = collections.OrderedDict(
      scale_factor=server_scale_factor,
      prior_norm_bound=server_prior_norm_bound,
      inner_agg_process=inner_agg_output.state)
  measurements = collections.OrderedDict(
      discretize=inner_agg_output.measurements)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=undiscretized_agg_value,
      measurements=intrinsics.federated_zip(measurements))
def next_fn(state, value):
  """One round: modular-clip client values, aggregate, re-clip the sum.

  Args:
    state: The inner aggregation process state.
    value: Client-placed values.

  Returns:
    A `MeasuredProcessOutput` whose result is the re-clipped aggregate and
    whose measurements carry the inner measurements under 'modclip' (plus an
    optional 'estimated_stddev').
  """
  clip_lower = intrinsics.federated_value(self._clip_range_lower,
                                          placements.SERVER)
  clip_upper = intrinsics.federated_value(self._clip_range_upper,
                                          placements.SERVER)
  # Modular clip values before aggregation.
  clipped_value = intrinsics.federated_map(
      modular_clip_by_value_fn,
      (value, intrinsics.federated_broadcast(clip_lower),
       intrinsics.federated_broadcast(clip_upper)))
  inner_agg_output = inner_agg_next(state, clipped_value)
  # Clip the aggregate to the same range again (not considering summands).
  clipped_agg_output_result = intrinsics.federated_map(
      modular_clip_by_value_fn,
      (inner_agg_output.result, clip_lower, clip_upper))
  measurements = collections.OrderedDict(
      modclip=inner_agg_output.measurements)
  if self._estimate_stddev:
    # Optional server-side stddev estimate of the clipped aggregate.
    estimate = intrinsics.federated_map(
        estimator_fn, (clipped_agg_output_result, clip_lower, clip_upper))
    measurements['estimated_stddev'] = estimate
  return measured_process.MeasuredProcessOutput(
      state=inner_agg_output.state,
      result=clipped_agg_output_result,
      measurements=intrinsics.federated_zip(measurements))
def _sum_securely(self, value, upper_bound, lower_bound):
  """Securely sums `value` placed at CLIENTS.

  Args:
    value: Client-placed values to be summed.
    upper_bound: Server-placed upper bound used for shifting/quantization.
    lower_bound: Server-placed lower bound used for shifting/quantization.

  Returns:
    The server-placed secure sum of the client values.

  Raises:
    ValueError: If `self._config_mode` is neither `_Config.INT` nor
      `_Config.FLOAT`.
  """
  if self._config_mode == _Config.INT:
    # Shift client values into the range expected by `federated_secure_sum`,
    # then undo the shift on the server, scaled by the number of summands.
    value = intrinsics.federated_map(
        _client_shift, (value, intrinsics.federated_broadcast(upper_bound),
                        intrinsics.federated_broadcast(lower_bound)))
    value = intrinsics.federated_secure_sum(value, self._secagg_bitwidth)
    # Count participating clients (each contributes 1) to rescale the shift.
    num_summands = intrinsics.federated_sum(_client_one())
    value = intrinsics.federated_map(_server_shift,
                                     (value, lower_bound, num_summands))
    return value
  elif self._config_mode == _Config.FLOAT:
    return primitives.secure_quantized_sum(value, lower_bound, upper_bound)
  else:
    raise ValueError(f'Unexpected internal config type: {self._config_mode}')
def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
  """One round: clip client values by norm, then run the inner aggregation.

  Args:
    state: Tuple of (clipping norm process state, inner aggregation state,
      clipped-count aggregation state).
    value: Client-placed values.
    clip_fn: Client computation returning (clipped value, global norm,
      was-clipped indicator).
    inner_agg_process: Aggregation process applied to the clipped values.
    weight: Optional client weights forwarded to `inner_agg_process.next`.

  Returns:
    A `MeasuredProcessOutput` with zipped state and measurements.
  """
  clipping_norm_state, agg_state, clipped_count_state = state
  # The current norm comes from the (possibly adaptive) norm process.
  clipping_norm = clipping_norm_process.report(clipping_norm_state)
  clipped_value, global_norm, was_clipped = intrinsics.federated_map(
      clip_fn, (value, intrinsics.federated_broadcast(clipping_norm)))
  # Observed global norms feed back to adapt the norm for the next round.
  new_clipping_norm_state = clipping_norm_process.next(
      clipping_norm_state, global_norm)
  if weight is None:
    agg_output = inner_agg_process.next(agg_state, clipped_value)
  else:
    agg_output = inner_agg_process.next(agg_state, clipped_value, weight)
  # Aggregate how many clients were actually clipped this round.
  clipped_count_output = clipped_count_agg_process.next(
      clipped_count_state, was_clipped)
  # NOTE(review): `prefix` presumably prepends something like 'clipp'/'zero'
  # to build keys such as 'clipping_norm' — confirm in the enclosing factory.
  new_state = collections.OrderedDict([
      (prefix('ing_norm'), new_clipping_norm_state),
      ('inner_agg', agg_output.state),
      (prefix('ed_count_agg'), clipped_count_output.state)
  ])
  measurements = collections.OrderedDict([
      (prefix('ing'), agg_output.measurements),
      (prefix('ing_norm'), clipping_norm),
      (prefix('ed_count'), clipped_count_output.result)
  ])
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=agg_output.result,
      measurements=intrinsics.federated_zip(measurements))
def add_server_number_plus_one(server_number, client_numbers):
  """Broadcasts (server_number + 1) and adds it to every client number."""
  server_one = intrinsics.federated_value(1, placements.SERVER)
  incremented = intrinsics.federated_map(add, (server_one, server_number))
  incremented_at_clients = intrinsics.federated_broadcast(incremented)
  return intrinsics.federated_map(add,
                                  (incremented_at_clients, client_numbers))
def _train_one_round(model, federated_data):
  """Broadcasts the model, trains locally on each client, and averages."""
  model_at_clients = intrinsics.federated_broadcast(model)
  client_args = collections.OrderedDict([('model', model_at_clients),
                                         ('batches', federated_data)])
  client_models = intrinsics.federated_map(_train_on_one_client, client_args)
  return intrinsics.federated_mean(client_models)
def stateful_broadcast(state, value):
  """Broadcasts `value` unchanged, reporting a fixed server-side metric."""
  fixed_metric = intrinsics.federated_value(3.0, placements.SERVER)
  value_at_clients = intrinsics.federated_broadcast(value)
  return measured_process_lib.MeasuredProcessOutput(
      state=state, result=value_at_clients, measurements=fixed_metric)
def next_fn(state, value):
  """One round of deterministic discretization around an inner aggregator.

  Args:
    state: Server-placed OrderedDict with 'step_size' and
      'inner_agg_process' state.
    value: Client-placed values.

  Returns:
    A `MeasuredProcessOutput` with the undiscretized aggregate and inner
    measurements (plus an optional aggregated 'distortion' metric).
  """
  server_step_size = state['step_size']
  client_step_size = intrinsics.federated_broadcast(server_step_size)
  # Clients quantize with the broadcast step size before aggregation.
  discretized_value = intrinsics.federated_map(discretize_fn,
                                               (value, client_step_size))
  inner_state = state['inner_agg_process']
  inner_agg_output = inner_agg_process.next(inner_state, discretized_value)
  # Server scales the aggregate back to the original domain.
  undiscretized_agg_value = intrinsics.federated_map(
      undiscretize_fn, (inner_agg_output.result, server_step_size))
  new_state = collections.OrderedDict(
      step_size=server_step_size,
      inner_agg_process=inner_agg_output.state)
  measurements = collections.OrderedDict(
      deterministic_discretization=inner_agg_output.measurements)
  if self._distortion_aggregation_factory is not None:
    # Optionally measure per-client quantization distortion and aggregate it
    # with a throwaway (freshly initialized) process each round.
    distortions = intrinsics.federated_map(distortion_measurement_fn,
                                           (value, client_step_size))
    aggregate_distortion = distortion_aggregation_process.next(
        distortion_aggregation_process.initialize(), distortions).result
    measurements['distortion'] = aggregate_distortion
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=undiscretized_agg_value,
      measurements=intrinsics.federated_zip(measurements))
def comp(temperatures, threshold):
  """Weighted mean fraction of client readings exceeding the threshold."""
  threshold_at_clients = intrinsics.federated_broadcast(threshold)
  zipped = intrinsics.federated_zip([temperatures, threshold_at_clients])
  over_counts = intrinsics.federated_map(count_over, zipped)
  total_counts = intrinsics.federated_map(count_total, temperatures)
  return intrinsics.federated_mean(over_counts, total_counts)
def encoded_broadcast_comp(state, value):
  """Encoded broadcast federated_computation.

  Encodes `value` on the server, broadcasts the encoded representation, and
  decodes it on the clients.

  Args:
    state: Server-placed encoder state.
    value: Server-placed value to broadcast.

  Returns:
    A `MeasuredProcessOutput` with updated state, the client-decoded value,
    and empty measurements.
  """
  empty_metrics = intrinsics.federated_value((), placements.SERVER)
  # Encoding also advances the encoder state.
  new_state, encoded_value = intrinsics.federated_map(encode_fn,
                                                      (state, value))
  client_encoded_value = intrinsics.federated_broadcast(encoded_value)
  client_value = intrinsics.federated_map(decode_fn, client_encoded_value)
  return measured_process.MeasuredProcessOutput(
      state=new_state, result=client_value, measurements=empty_metrics)
def next_comp(state, value):
  """Increments server state, broadcasts `value`, and reports a norm metric."""
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_map(_add_one, state),
      result=intrinsics.federated_broadcast(value),
      # Arbitrary metrics for testing: global norm of the value plus 3.0.
      measurements=intrinsics.federated_map(
          tensorflow_computation.tf_computation(
              lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0),
          value))
def computation(arg):
  """Runs client processing on client data zipped with broadcast context."""
  server_data, client_data = arg
  server_context = intrinsics.federated_map(bf.compute_server_context,
                                            server_data)
  client_context = intrinsics.federated_broadcast(server_context)
  processing_arg = intrinsics.federated_zip((client_context, client_data))
  return intrinsics.federated_map(bf.client_processing, processing_arg)
def next_fn(server_state, client_val):
  """`next` function for `tff.templates.IterativeProcess`."""
  server_update = intrinsics.federated_sum(client_val)
  # NOTE(review): this placeholder is rebound below before being returned —
  # presumably a deliberate part of the test fixture; confirm intent.
  server_output = intrinsics.federated_value((), placements.SERVER)
  state_at_clients = intrinsics.federated_broadcast(server_state)
  # Exercise a computation that returns a lambda, then apply it.
  lambda_returning_sum = computation_returning_lambda()
  sum_fn = lambda_returning_sum(1)
  server_output = sum_fn(state_at_clients)
  return server_update, server_output
def next_fn(state, weights, client_data):
  """One training round: broadcast round number, update clients, aggregate.

  Args:
    state: Server-placed round number.
    weights: Model weights passed to the client update.
    client_data: Client-placed datasets.

  Returns:
    A `MeasuredProcessOutput` of (incremented state, client results, zipped
    train metrics).
  """
  round_num_at_clients = intrinsics.federated_broadcast(state)
  client_result, model_outputs = intrinsics.federated_map(
      client_update_computation,
      (weights, client_data, round_num_at_clients))
  # Advance the round counter on the server.
  updated_state = intrinsics.federated_map(add_one, state)
  train_metrics = metrics_aggregation_fn(model_outputs)
  measurements = intrinsics.federated_zip(
      collections.OrderedDict(train=train_metrics))
  return measured_process.MeasuredProcessOutput(updated_state, client_result,
                                                measurements)
def aggregation_comp(server_arg, client_arg):
  """Sums client tuples, rebroadcasts, and sums again via a tupled call.

  This structure deliberately routes an aggregation result through a function
  call and a broadcast before a second aggregation.
  """
  client_sums = intrinsics.federated_map(compute_tuple_sum, client_arg)
  summed_client_value = intrinsics.federated_sum(client_sums)
  broadcast_sum = intrinsics.federated_broadcast(summed_client_value)
  # Adding a function call here requires normalization into CDF before
  # checking the aggregation-dependence condition.
  client_tuple = package_args_as_tuple(client_sums, broadcast_sum)
  summed_client_value = intrinsics.federated_sum(client_tuple[0])
  return intrinsics.federated_map(compute_sum,
                                  (server_arg, summed_client_value))
def next_comp(server_state, client_data):
  """Deliberately broadcasts an aggregation result (unsupported pattern)."""
  del server_state, client_data
  client_val = intrinsics.federated_value(0, placements.CLIENTS)
  server_agg = intrinsics.federated_sum(client_val)
  # This broadcast is dependent on the result of the above aggregation,
  # which is not supported by MapReduce form.
  broadcasted = intrinsics.federated_broadcast(server_agg)
  server_agg_again = intrinsics.federated_sum(broadcasted)
  # `next` must return two values.
  return server_agg_again, intrinsics.federated_value((), placements.SERVER)
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Canonical prepare/broadcast/work/sum/update round, intentionally using
  only the unsecure sum path.
  """
  s2 = intrinsics.federated_map(prepare, server_state)
  client_input = intrinsics.federated_broadcast(s2)
  c3 = intrinsics.federated_zip([client_data, client_input])
  client_updates = intrinsics.federated_map(work, c3)
  unsecure_update = intrinsics.federated_sum(client_updates)
  # No call to `federated_secure_sum_bitwidth`.
  s6 = intrinsics.federated_zip([server_state, unsecure_update])
  new_server_state, server_output = intrinsics.federated_map(update, s6)
  return new_server_state, server_output
def _compute_measurements(self, upper_bound, lower_bound, value_max,
                          value_min):
  """Creates measurements to be reported. All values are summed securely.

  Args:
    upper_bound: Server-placed upper clipping threshold.
    lower_bound: Server-placed lower clipping threshold.
    value_max: Client-placed per-client maximum values.
    value_min: Client-placed per-client minimum values.

  Returns:
    A server-placed zipped OrderedDict of clipped counts and thresholds.
  """
  # 1 per client whose max exceeded the upper bound; secure-summed with a
  # single bit since each client contributes 0 or 1.
  is_max_clipped = intrinsics.federated_map(
      computations.tf_computation(
          lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE)),
      (intrinsics.federated_broadcast(upper_bound), value_max))
  max_clipped_count = intrinsics.federated_secure_sum(is_max_clipped,
                                                      bitwidth=1)
  # Symmetric count for clients whose min fell below the lower bound.
  is_min_clipped = intrinsics.federated_map(
      computations.tf_computation(
          lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE)),
      (intrinsics.federated_broadcast(lower_bound), value_min))
  min_clipped_count = intrinsics.federated_secure_sum(is_min_clipped,
                                                      bitwidth=1)
  measurements = collections.OrderedDict(
      secure_upper_clipped_count=max_clipped_count,
      secure_lower_clipped_count=min_clipped_count,
      secure_upper_threshold=upper_bound,
      secure_lower_threshold=lower_bound)
  return intrinsics.federated_zip(measurements)
def next_fn(server_state, client_val):
  """`next` function for `tff.templates.IterativeProcess`."""
  server_update = intrinsics.federated_zip(
      collections.OrderedDict(
          num_clients=count_clients_federated(client_val)))
  # NOTE(review): this placeholder is rebound by the sum below before being
  # returned — presumably deliberate in this test fixture; confirm intent.
  server_output = intrinsics.federated_value((), placements.SERVER)
  # Binds `tf.timestamp` into the broadcast value, then sums the results.
  server_output = intrinsics.federated_sum(
      _bind_tf_function(
          intrinsics.federated_broadcast(server_state), tf.timestamp))
  return server_update, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Like the canonical round, but routes the broadcast argument through a
  helper that returns both the argument and the result.
  """
  s2 = intrinsics.federated_map(prepare, server_state)
  # The first returned element is intentionally unused; only the value to
  # broadcast is kept.
  unused_client_input, to_broadcast = broadcast_and_return_arg_and_result(s2)
  client_input = intrinsics.federated_broadcast(to_broadcast)
  c3 = intrinsics.federated_zip([client_data, client_input])
  client_updates = intrinsics.federated_map(work, c3)
  # First work output is summed normally; second via 8-bit secure sum.
  unsecure_update = intrinsics.federated_sum(client_updates[0])
  secure_update = intrinsics.federated_secure_sum_bitwidth(
      client_updates[1], 8)
  s6 = intrinsics.federated_zip(
      [server_state, [unsecure_update, secure_update]])
  new_server_state, server_output = intrinsics.federated_map(update, s6)
  return new_server_state, server_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`.

  Variant that intentionally omits the final `update` map: the new state is
  just the zipped aggregates and the output is an empty server value.
  """
  s2 = intrinsics.federated_map(prepare, server_state)
  client_input = intrinsics.federated_broadcast(s2)
  c3 = intrinsics.federated_zip([client_data, client_input])
  client_updates = intrinsics.federated_map(work, c3)
  # First work output is summed normally; second via 8-bit secure sum.
  unsecure_update = intrinsics.federated_sum(client_updates[0])
  secure_update = intrinsics.federated_secure_sum_bitwidth(
      client_updates[1], 8)
  new_server_state = intrinsics.federated_zip(
      [unsecure_update, secure_update])
  # No call to `federated_map` with an `update` function.
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def personalization_eval(server_model_weights, federated_client_input):
  """TFF orchestration logic.

  Broadcasts the server model, evaluates personalization strategies on each
  client, and aggregates (samples) the resulting client metrics.

  Args:
    server_model_weights: Server-placed model weights.
    federated_client_input: Client-placed personalization inputs.

  Returns:
    The result of the sampling aggregation over client metrics.
  """
  client_init_weights = intrinsics.federated_broadcast(server_model_weights)
  client_final_metrics = intrinsics.federated_map(
      _client_computation, (client_init_weights, federated_client_input))
  # WARNING: Collecting information from clients can be risky. Users have to
  # make sure that it is proper to collect those metrics from clients.
  # TODO(b/147889283): Add a link to the TFF doc once it exists.
  sampling_output = aggregation_process.next(
      aggregation_process.initialize(),  # No state.
      client_final_metrics)
  # In the future we may want to output `sampling_output.measurements` also
  # but currently it is empty.
  return sampling_output.result
def server_eval(server_model_weights, federated_dataset):
  """Evaluates broadcast model weights on client datasets.

  Uses `broadcast_process` when provided (its measurements are currently
  dropped); otherwise falls back to a plain `federated_broadcast`.
  """
  if broadcast_process is not None:
    # TODO(b/179091838): Confirm that the process has no state.
    # TODO(b/179091838): Zip the measurements from the broadcast_process with
    # the result of `model.federated_output_computation` below to avoid
    # dropping these metrics.
    broadcast_output = broadcast_process.next(
        broadcast_process.initialize(), server_model_weights)
    client_outputs = intrinsics.federated_map(
        client_eval, (broadcast_output.result, federated_dataset))
  else:
    client_outputs = intrinsics.federated_map(client_eval, [
        intrinsics.federated_broadcast(server_model_weights),
        federated_dataset
    ])
  return model.federated_output_computation(client_outputs.local_outputs)
def next_fn(state, value):
  """One round of quantile estimation via a DP quantile query.

  Args:
    state: Tuple of (quantile query state, inner aggregation state).
    value: Client-placed values.

  Returns:
    The zipped new (quantile query state, aggregation state).
  """
  quantile_query_state, agg_state = state
  # Derive and broadcast this round's sampling parameters.
  params = intrinsics.federated_broadcast(
      intrinsics.federated_map(derive_sample_params, quantile_query_state))
  quantile_record = intrinsics.federated_map(get_quantile_record,
                                             (params, value))
  # Inner process here is a `MeasuredProcess` (compare the tuple-returning
  # variant elsewhere in this file).
  quantile_agg_output = quantile_agg_process.next(agg_state, quantile_record)
  # Only the updated query state is needed; the noised result and the third
  # output are discarded.
  _, new_quantile_query_state, _ = intrinsics.federated_map(
      get_noised_result,
      (quantile_agg_output.result, quantile_query_state))
  return intrinsics.federated_zip(
      (new_quantile_query_state, quantile_agg_output.state))
def next_fn(state, value):
  """One round of quantile estimation with a tuple-returning inner process.

  Args:
    state: Tuple of (quantile query state, inner aggregation state).
    value: Client-placed values.

  Returns:
    The zipped new (quantile query state, aggregation state).
  """
  quantile_query_state, agg_state = state
  # Derive and broadcast this round's sampling parameters.
  params = intrinsics.federated_broadcast(
      intrinsics.federated_map(derive_sample_params, quantile_query_state))
  quantile_record = intrinsics.federated_map(get_quantile_record,
                                             (params, value))
  (new_agg_state, agg_result,
   agg_measurements) = quantile_agg_process.next(agg_state, quantile_record)
  # We expect the quantile record aggregation process to be something simple
  # like basic sum, so we won't surface its measurements.
  del agg_measurements
  _, new_quantile_query_state = intrinsics.federated_map(
      get_noised_result, (agg_result, quantile_query_state))
  return intrinsics.federated_zip((new_quantile_query_state, new_agg_state))
def next_computation(arg):
  """The logic of a single MapReduce processing round."""
  server_state = arg[0]
  client_data = arg[1]
  prepared = intrinsics.federated_map(mrf.prepare, server_state)
  broadcast_value = intrinsics.federated_broadcast(prepared)
  work_arg = intrinsics.federated_zip([client_data, broadcast_value])
  work_result = intrinsics.federated_map(mrf.work, work_arg)
  aggregate_input = work_result[0]
  secure_input = work_result[1]
  aggregated = intrinsics.federated_aggregate(aggregate_input, mrf.zero(),
                                              mrf.accumulate, mrf.merge,
                                              mrf.report)
  securely_summed = intrinsics.federated_secure_sum_bitwidth(
      secure_input, mrf.bitwidth())
  combined_updates = intrinsics.federated_zip([aggregated, securely_summed])
  update_arg = intrinsics.federated_zip([server_state, combined_updates])
  update_result = intrinsics.federated_map(mrf.update, update_arg)
  new_server_state = update_result[0]
  server_output = update_result[1]
  return new_server_state, server_output
def next_fn_impl(state, value):
  """One round: transform client values, aggregate, transform back on server.

  Args:
    state: Tuple of (inner aggregation state, this wrapper's own state).
    value: Client-placed values.

  Returns:
    A `MeasuredProcessOutput` with zipped state and measurements.
  """
  inner_state, my_state = state
  # Clients need the wrapper state to apply the forward transform.
  client_my_state = intrinsics.federated_broadcast(my_state)
  projected_value = intrinsics.federated_map(client_transform,
                                             (value, client_my_state))
  inner_agg_output = inner_agg_process.next(inner_state, projected_value)
  # Invert the transform on the server using the un-broadcast state.
  aggregate_value = intrinsics.federated_map(
      server_transform, (inner_agg_output.result, my_state))
  new_state = (inner_agg_output.state,
               intrinsics.federated_map(update_my_state, my_state))
  # Inner measurements are surfaced under this wrapper's `name`.
  measurements = collections.OrderedDict([(name,
                                           inner_agg_output.measurements)])
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=aggregate_value,
      measurements=intrinsics.federated_zip(measurements))
def server_eval(server_model_weights, federated_dataset):
  """Evaluates broadcast model weights, returning zipped 'eval' metrics."""
  client_eval = build_local_evaluation(model_fn, model_weights_type,
                                       batch_type,
                                       use_experimental_simulation_loop)
  if broadcast_process is not None:
    # TODO(b/179091838): Zip the measurements from the broadcast_process with
    # the result of `model_metrics` below to avoid dropping these metrics.
    broadcast_output = broadcast_process.next(
        broadcast_process.initialize(), server_model_weights)
    client_outputs = intrinsics.federated_map(
        client_eval, (broadcast_output.result, federated_dataset))
  else:
    # No broadcast process configured; use a plain broadcast.
    client_outputs = intrinsics.federated_map(client_eval, [
        intrinsics.federated_broadcast(server_model_weights),
        federated_dataset
    ])
  model_metrics = metrics_aggregation_computation(
      client_outputs.local_outputs)
  return intrinsics.federated_zip(
      collections.OrderedDict(eval=model_metrics))
def server_eval(server_model_weights, federated_dataset):
  """Evaluates broadcast model weights, returning 'eval' and 'stat' metrics."""
  if broadcast_process is not None:
    # TODO(b/179091838): Zip the measurements from the broadcast_process with
    # the result of `model.federated_output_computation` below to avoid
    # dropping these metrics.
    broadcast_output = broadcast_process.next(
        broadcast_process.initialize(), server_model_weights)
    client_outputs = intrinsics.federated_map(
        client_eval, (broadcast_output.result, federated_dataset))
  else:
    # No broadcast process configured; use a plain broadcast.
    client_outputs = intrinsics.federated_map(client_eval, [
        intrinsics.federated_broadcast(server_model_weights),
        federated_dataset
    ])
  model_metrics = model.federated_output_computation(
      client_outputs.local_outputs)
  # Also report the total number of examples seen across clients.
  statistics = collections.OrderedDict(
      num_examples=intrinsics.federated_sum(client_outputs.num_examples))
  return intrinsics.federated_zip(
      collections.OrderedDict(eval=model_metrics, stat=statistics))