def broadcast_next_fn(state, value):
  """Increments the server-side call counter and broadcasts `value`."""

  @computations.tf_computation(tf.int32)
  def add_one(count):
    return count + 1

  updated_state = intrinsics.federated_zip(
      collections.OrderedDict(
          call_count=intrinsics.federated_map(add_one, state.call_count)))
  return updated_state, intrinsics.federated_broadcast(value)
def initialize_computation():
  """Orchestration logic for server model initialization."""
  global_model, global_optimizer_state = intrinsics.federated_eval(
      server_init, placements.SERVER)
  initial_state = ServerState(
      model=global_model,
      optimizer_state=global_optimizer_state,
      delta_aggregate_state=aggregation_process.initialize(),
      model_broadcast_state=broadcast_process.initialize())
  return intrinsics.federated_zip(initial_state)
def comp(temperatures, threshold):
  """Returns the weighted mean of per-client `count_over` results."""
  zipped_client_data = intrinsics.federated_zip(
      [temperatures, intrinsics.federated_broadcast(threshold)])
  over_counts = intrinsics.federated_map(count_over, zipped_client_data)
  total_counts = intrinsics.federated_map(count_total, temperatures)
  return intrinsics.federated_mean(over_counts, total_counts)
def next_computation(arg):
  """The logic of a single MapReduce processing round."""
  server_state = arg[0]
  client_data = arg[1]
  # Server prepares a value and broadcasts it to the clients.
  prepared = intrinsics.federated_map(cf.prepare, server_state)
  broadcast = intrinsics.federated_broadcast(prepared)
  # Clients combine their data with the broadcast and do local work.
  client_input = intrinsics.federated_zip([client_data, broadcast])
  work_result = intrinsics.federated_map(cf.work, client_input)
  aggregate_input = work_result[0]
  secure_input = work_result[1]
  # One generic aggregate and one secure sum back to the server.
  aggregated = intrinsics.federated_aggregate(aggregate_input, cf.zero(),
                                              cf.accumulate, cf.merge,
                                              cf.report)
  securely_summed = intrinsics.federated_secure_sum(secure_input,
                                                    cf.bitwidth())
  # Server updates its state from both aggregation results.
  update_arg = intrinsics.federated_zip(
      [server_state,
       intrinsics.federated_zip([aggregated, securely_summed])])
  updated = intrinsics.federated_map(cf.update, update_arg)
  return updated[0], updated[1]
def comp():
  """Returns a zipped ordered dict of two client-placed constants."""
  constants = collections.OrderedDict([
      ('A', intrinsics.federated_value(10, placement_literals.CLIENTS)),
      ('B', intrinsics.federated_value(20, placement_literals.CLIENTS)),
  ])
  return intrinsics.federated_zip(constants)
def one_round_computation(server_state, federated_dataset):
  """Orchestration logic for one round of optimization.

  Args:
    server_state: a `tff.learning.framework.ServerState` named tuple.
    federated_dataset: a federated `tf.Dataset` with placement tff.CLIENTS.

  Returns:
    A tuple of updated `tff.learning.framework.ServerState` and the result of
    `tff.learning.Model.federated_output_computation`, both having `tff.SERVER`
    placement.
  """
  # Push the current global model to the clients via the broadcast process.
  broadcast_output = broadcast_process.next(
      server_state.model_broadcast_state, server_state.model)
  # Each client trains locally on its dataset and produces a model delta.
  client_outputs = intrinsics.federated_map(
      _compute_local_training_and_client_delta,
      (federated_dataset, broadcast_output.result))
  # The aggregation process may or may not take per-client weights; branch on
  # the arity of its `next` signature (3 params means state, value, weight).
  if len(aggregation_process.next.type_signature.parameter) == 3:
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        client_outputs.weights_delta_weight)
  else:
    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta)
  # Apply the aggregated delta to the server model/optimizer state.
  new_global_model, new_optimizer_state = intrinsics.federated_map(
      server_update, (server_state.model, aggregation_output.result,
                      server_state.optimizer_state))
  new_server_state = intrinsics.federated_zip(
      ServerState(new_global_model, new_optimizer_state,
                  aggregation_output.state, broadcast_output.state))
  # Collect metrics: model outputs, optimizer outputs, and the measurements
  # reported by the broadcast and aggregation processes.
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  optimizer_outputs = intrinsics.federated_sum(
      client_outputs.optimizer_output)
  measurements = intrinsics.federated_zip(
      collections.OrderedDict(
          broadcast=broadcast_output.measurements,
          aggregation=aggregation_output.measurements,
          train=aggregated_outputs,
          stat=optimizer_outputs))
  return new_server_state, measurements
def next_fn(server_state, client_data):
  """The `next` function for `computation_utils.IterativeProcess`."""
  del server_state  # Unused
  client_updates, client_output = intrinsics.federated_map(work, client_data)
  summed_update = intrinsics.federated_sum(client_updates[0])
  secure_summed_update = intrinsics.federated_secure_sum(client_updates[1], 8)
  update_arg = intrinsics.federated_zip(
      [summed_update, secure_summed_update])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output, client_output
def next_fn(server_val, client_val):
  """Defines a series of federated computations compatible with CanonicalForm."""
  broadcast_val = intrinsics.federated_broadcast(server_val)
  zipped_values = intrinsics.federated_zip((client_val, broadcast_val))
  client_results = intrinsics.federated_map(add_two, zipped_values)
  mean_result = intrinsics.federated_mean(client_results)
  side_output = intrinsics.federated_value([1, 2, 3, 4, 5], placements.SERVER)
  return mean_result, side_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`."""
  # Deliberately skips the `prepare` map and the broadcast step.
  client_updates, client_output = intrinsics.federated_map(work, client_data)
  summed_update = intrinsics.federated_sum(client_updates[0])
  secure_summed_update = intrinsics.federated_secure_sum(client_updates[1], 8)
  update_arg = intrinsics.federated_zip(
      [server_state, [summed_update, secure_summed_update]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output, client_output
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`."""
  del client_data  # Unused
  # Constant server-placed values stand in for `federated_aggregate` and
  # `federated_secure_sum` results, which this variant deliberately omits.
  unsecure_update = intrinsics.federated_value(1, placements.SERVER)
  secure_update = intrinsics.federated_value(1, placements.SERVER)
  update_arg = intrinsics.federated_zip(
      [server_state, [unsecure_update, secure_update]])
  new_server_state, server_output = intrinsics.federated_map(
      update, update_arg)
  return new_server_state, server_output
def initialize_computation():
  """Builds the initial federated `ServerState` for the iterative process.

  Returns:
    A server-placed `ServerState` with the freshly initialized model,
    optimizer state, a zero round counter, and the aggregation process's
    initial state.
  """
  # NOTE(review): `model` appears unused below — confirm `model_fn()` is not
  # needed for construction side effects before removing this line.
  model = model_fn()
  initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
      server_init_tf, placements.SERVER)
  # Use `placements`/`intrinsics` consistently (sibling initializers in this
  # file use them) instead of the `tff.SERVER`/`tff.federated_value` aliases.
  return intrinsics.federated_zip(
      ServerState(
          model=initial_global_model,
          optimizer_state=initial_global_optimizer_state,
          round_num=intrinsics.federated_value(0.0, placements.SERVER),
          aggregation_state=aggregation_process.initialize(),
      ))
def next_fn(server_state, client_val):
  """`next` function for `computation_utils.IterativeProcess`."""
  server_update = intrinsics.federated_zip(
      collections.OrderedDict(
          num_clients=count_clients_federated(client_val)))
  # Start from an empty server-placed output, then bind the broadcast state.
  empty_output = intrinsics.federated_value((), placements.SERVER)
  server_output = _bind_federated_value(
      intrinsics.federated_broadcast(server_state), server_state_type,
      empty_output)
  return server_update, server_output
def next_fn(server_state, client_val):
  """`next` function for `tff.utils.IterativeProcess`."""
  server_update = intrinsics.federated_zip(
      collections.OrderedDict([
          ('num_clients', count_clients_federated(client_val)),
      ]))
  # Placeholder output, immediately replaced by the summed timestamps below
  # (kept to mirror the original construction order).
  server_output = intrinsics.federated_value((), placements.SERVER)
  server_output = intrinsics.federated_sum(
      _bind_tf_function(
          intrinsics.federated_broadcast(server_state), tf.timestamp))
  return server_update, server_output
def next_fn(state, value, weight):
  """One aggregation round with adaptive clipping and zeroing.

  NOTE(review): this closure captures `self`, `clip_and_zero`,
  `inner_agg_process`, and the two count aggregation processes from the
  enclosing scope.
  """
  # Unpack the four sub-states carried between rounds.
  (clipping_norm_state, agg_state, clipped_count_state,
   zeroed_count_state) = state
  clipping_norm = self._clipping_norm_process.report(
      clipping_norm_state)
  # The zeroing norm is derived from the clipping norm on the server.
  zeroing_norm = intrinsics.federated_map(self._zeroing_norm_fn,
                                          clipping_norm)
  # Clients clip/zero their values against the broadcast norms and report
  # per-client indicators of whether clipping/zeroing happened.
  (zeroed_and_clipped, global_norm, was_clipped,
   was_zeroed) = intrinsics.federated_map(
       clip_and_zero,
       (value, intrinsics.federated_broadcast(clipping_norm),
        intrinsics.federated_broadcast(zeroing_norm)))
  # Advance the norm process using the observed global norm.
  new_clipping_norm_state = self._clipping_norm_process.next(
      clipping_norm_state, global_norm)
  agg_output = inner_agg_process.next(agg_state, zeroed_and_clipped, weight)
  clipped_count_output = clipped_count_agg_process.next(
      clipped_count_state, was_clipped)
  zeroed_count_output = zeroed_count_agg_process.next(
      zeroed_count_state, was_zeroed)
  new_state = collections.OrderedDict(
      clipping_norm=new_clipping_norm_state,
      inner_agg=agg_output.state,
      clipped_count_agg=clipped_count_output.state,
      zeroed_count_agg=zeroed_count_output.state)
  measurements = collections.OrderedDict(
      agg_process=agg_output.measurements,
      clipping_norm=clipping_norm,
      zeroing_norm=zeroing_norm,
      clipped_count=clipped_count_output.result,
      zeroed_count=zeroed_count_output.result)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=agg_output.result,
      measurements=intrinsics.federated_zip(measurements))
def initialize_computation():
  """Builds the initial federated `ServerState` for the iterative process.

  Returns:
    A server-placed `ServerState` with the freshly initialized model,
    optimizer state, a zero round counter, the effective client count, and
    the aggregation process's initial state.
  """
  # NOTE(review): `model` appears unused below — confirm `model_fn()` is not
  # needed for construction side effects before removing this line.
  model = model_fn()
  initial_global_model, initial_global_optimizer_state = intrinsics.federated_eval(
      server_init_tf, placements.SERVER)
  # This block already uses `placements.SERVER` twice; normalize the lone
  # `tff.federated_value(0.0, tff.SERVER)` to the same modules for consistency.
  return intrinsics.federated_zip(
      ServerState(
          model=initial_global_model,
          optimizer_state=initial_global_optimizer_state,
          round_num=intrinsics.federated_value(0.0, placements.SERVER),
          effective_num_clients=intrinsics.federated_eval(
              get_effective_num_clients, placements.SERVER),
          delta_aggregate_state=aggregation_process.initialize(),
      ))
def comp(temperatures, threshold):
  """Counts, per client, readings strictly greater than `threshold`."""

  @computations.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def count(ds, t):
    accumulate = lambda n, x: n + tf.cast(tf.greater(x, t), tf.int32)
    return ds.reduce(np.int32(0), accumulate)

  zipped = intrinsics.federated_zip(
      [temperatures, intrinsics.federated_broadcast(threshold)])
  return intrinsics.federated_map(count, zipped)
def next_fn(server_state, client_data):
  """The `next` function for `tff.templates.IterativeProcess`."""
  del server_state  # Unused
  # Deliberately omits the `prepare` map, the broadcast, and the final
  # `update` map present in other variants.
  client_updates = intrinsics.federated_map(work, client_data)
  unsecure_total = intrinsics.federated_sum(client_updates[0])
  secure_total = intrinsics.federated_secure_sum(client_updates[1], 8)
  new_server_state = intrinsics.federated_zip(
      [unsecure_total, secure_total])
  server_output = intrinsics.federated_value([], placements.SERVER)
  return new_server_state, server_output
def next_fn(state, value):
  """Runs one round of the DP query over the client values."""
  query_state, agg_state = state
  # Derive per-round sample parameters on the server, then broadcast.
  sample_params = intrinsics.federated_broadcast(
      intrinsics.federated_map(derive_sample_params, query_state))
  record = intrinsics.federated_map(get_query_record,
                                    (sample_params, value))
  new_agg_state, agg_result, agg_measurements = record_agg_process.next(
      agg_state, record)
  # Post-process the aggregate on the server to get the noised result.
  result, new_query_state = intrinsics.federated_map(
      get_noised_result, (agg_result, query_state))
  query_metrics = intrinsics.federated_map(derive_metrics, new_query_state)
  measurements = collections.OrderedDict(
      dp_query_metrics=query_metrics, dp=agg_measurements)
  return measured_process.MeasuredProcessOutput(
      intrinsics.federated_zip((new_query_state, new_agg_state)), result,
      intrinsics.federated_zip(measurements))
def next_fn(state, value):
  """Computes an unweighted mean: an inner sum divided by the client count."""
  sum_output = value_sum_process.next(state, value)
  # Each client contributes a constant 1, so the federated sum is the count.
  client_count = intrinsics.federated_sum(
      intrinsics.federated_value(1, placements.CLIENTS))
  mean_value = intrinsics.federated_map(
      _div, (sum_output.result, client_count))
  measurements = intrinsics.federated_zip(
      collections.OrderedDict(mean_value=sum_output.measurements))
  return measured_process.MeasuredProcessOutput(
      sum_output.state, mean_value, measurements)
def next_fn(state, value, weight):
  """Computes a weighted mean via separate value and weight sums."""
  # Client computation: scale each value by its weight.
  scaled_value = intrinsics.federated_map(_mul, (value, weight))
  # Inner aggregations of the scaled values and of the raw weights.
  value_output = value_sum_process.next(state['value_sum_process'],
                                        scaled_value)
  weight_output = weight_sum_process.next(state['weight_sum_process'],
                                          weight)
  # Server computation: divide the two sums.
  weighted_mean_value = intrinsics.federated_map(
      _div, (value_output.result, weight_output.result))
  # Output preparation.
  new_state = collections.OrderedDict(
      value_sum_process=value_output.state,
      weight_sum_process=weight_output.state)
  measurements = collections.OrderedDict(
      value_sum_process=value_output.measurements,
      weight_sum_process=weight_output.measurements)
  return measured_process.MeasuredProcessOutput(
      intrinsics.federated_zip(new_state), weighted_mean_value,
      intrinsics.federated_zip(measurements))
def next_fn(state, value, weight):
  """Zeroes out oversized values, then delegates to the inner aggregation."""
  norm_state, inner_state = state
  zeroing_norm = self._zeroing_norm_process.report(norm_state)
  # Clients zero their values against the broadcast norm and report the
  # observed norm back to the server.
  zeroed, observed_norm = intrinsics.federated_map(
      zero, (value, intrinsics.federated_broadcast(zeroing_norm)))
  inner_output = inner_agg_process.next(inner_state, zeroed, weight)
  new_norm_state = self._zeroing_norm_process.next(norm_state, observed_norm)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip((new_norm_state, inner_output.state)),
      result=inner_output.result,
      measurements=inner_output.measurements)
def next_fn(state, value, weight):
  """Clips values to a norm bound, then delegates to the inner aggregation."""
  norm_state, inner_state = state
  clipping_norm = self._clipping_norm_process.report(norm_state)
  # Clients clip their values against the broadcast norm and report the
  # observed global norm back to the server.
  clipped_value, observed_norm = intrinsics.federated_map(
      clip, (value, intrinsics.federated_broadcast(clipping_norm)))
  inner_output = inner_agg_process.next(inner_state, clipped_value, weight)
  new_norm_state = self._clipping_norm_process.next(norm_state,
                                                    observed_norm)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip((new_norm_state, inner_output.state)),
      result=inner_output.result,
      measurements=inner_output.measurements)
def next_fn(state, value):
  """Advances the quantile-estimation state by one round.

  Returns only the new zipped state; the noised quantile result itself is
  discarded here (only the updated query state is kept).
  """
  quantile_query_state, agg_state = state
  # Derive per-round sample parameters on the server and broadcast them.
  params = intrinsics.federated_broadcast(
      intrinsics.federated_map(derive_sample_params, quantile_query_state))
  quantile_record = intrinsics.federated_map(get_quantile_record,
                                             (params, value))
  (new_agg_state, agg_result,
   agg_measurements) = quantile_agg_process.next(agg_state, quantile_record)
  # We expect the quantile record aggregation process to be something simple
  # like basic sum, so we won't surface its measurements.
  del agg_measurements
  # Only the updated query state is needed; the noised result is dropped.
  _, new_quantile_query_state = intrinsics.federated_map(
      get_noised_result, (agg_result, quantile_query_state))
  return intrinsics.federated_zip((new_quantile_query_state, new_agg_state))
def _compute_measurements(self, upper_bound, lower_bound, value_max,
                          value_min):
  """Creates measurements to be reported. All values are summed securely."""
  # Per-client indicator: 1 if the client's max exceeded the upper bound.
  above_upper = intrinsics.federated_map(
      computations.tf_computation(
          lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE)),
      (intrinsics.federated_broadcast(upper_bound), value_max))
  max_clipped_count = intrinsics.federated_secure_sum(above_upper, bitwidth=1)
  # Per-client indicator: 1 if the client's min fell below the lower bound.
  below_lower = intrinsics.federated_map(
      computations.tf_computation(
          lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE)),
      (intrinsics.federated_broadcast(lower_bound), value_min))
  min_clipped_count = intrinsics.federated_secure_sum(below_lower, bitwidth=1)
  return intrinsics.federated_zip(
      collections.OrderedDict(
          upper_bound_clipped_count=max_clipped_count,
          lower_bound_clipped_count=min_clipped_count,
          upper_bound_threshold=upper_bound,
          lower_bound_threshold=lower_bound))
def next_fn(global_state, value, weight):
  """One round of adaptive zeroing followed by a weighted mean."""
  sample_params = intrinsics.federated_broadcast(
      intrinsics.federated_map(derive_sample_params, global_state))
  # Clients preprocess: weighted value, adjusted weight, a quantile record,
  # and an indicator for values that were too large.
  (scaled_value, adjusted_weight, quantile_record,
   too_large) = intrinsics.federated_map(
       preprocess_value, (sample_params, value, weight))
  summed_values = intrinsics.federated_sum(scaled_value)
  summed_weights = intrinsics.federated_sum(adjusted_weight)
  summed_quantiles = intrinsics.federated_sum(quantile_record)
  zeroed_count = intrinsics.federated_sum(too_large)
  mean_value = intrinsics.federated_map(
      divide_no_nan, (summed_values, summed_weights))
  # Advance the quantile estimate to get the next round's threshold.
  new_threshold, new_global_state = intrinsics.federated_map(
      next_quantile, (summed_quantiles, global_state))
  metrics = intrinsics.federated_zip(
      AdaptiveZeroingMetrics(new_threshold, zeroed_count))
  return measured_process.MeasuredProcessOutput(
      state=new_global_state, result=mean_value, measurements=metrics)
def baz(x):
  """Zips `x` and asserts the result is a `value_base.Value`."""
  zipped = intrinsics.federated_zip(x)
  self.assertIsInstance(zipped, value_base.Value)
  return zipped
def bar(x):
  """Zips a `Struct` of `n` test tuples built from `x`."""
  test_struct = structure.Struct(
      _make_test_tuple(x, k) for k in range(n))
  zipped = intrinsics.federated_zip(test_struct)
  self.assertIsInstance(zipped, value_base.Value)
  return zipped
def foo(x):
  """Zips a dict of `x`'s first `n` elements keyed by stringified index."""
  keyed_values = {str(k): x[k] for k in range(n)}
  zipped = intrinsics.federated_zip(keyed_values)
  self.assertIsInstance(zipped, value_base.Value)
  return zipped
def _(arg):
  """Returns the zipped form of `arg`."""
  zipped = intrinsics.federated_zip(arg)
  return zipped
def foo(arg):
  """Zips `arg` and asserts the result is a `value_base.Value`."""
  zipped = intrinsics.federated_zip(arg)
  self.assertIsInstance(zipped, value_base.Value)
  return zipped