def run_one_round(server_state, federated_dataset):
  """Performs one federated round with adaptive learning-rate callbacks.

  Besides applying the averaged client gradients to the server model, this
  extracts the metrics named by `client_monitor` and `server_monitor` from the
  pre-training client evaluation and passes them to `server_update_fn`, which
  updates the learning-rate callbacks.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and an `OrderedDict` of aggregated
    model metrics measured before and during local client training.
  """
  broadcast_model = tff.federated_broadcast(server_state.model)
  broadcast_lr = tff.federated_broadcast(
      server_state.client_lr_callback.learning_rate)

  if dataset_preprocess_comp is not None:
    federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                          federated_dataset)

  outputs = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_model, broadcast_lr))
  weight = outputs.client_weight
  mean_gradients = tff.federated_mean(
      outputs.accumulated_gradients, weight=weight)

  def _aggregate(model_output):
    # Aggregates client metrics; zips struct-typed results onto the server.
    aggregated = dummy_model.federated_output_computation(model_output)
    if isinstance(aggregated.type_signature, tff.StructType):
      aggregated = tff.federated_zip(aggregated)
    return aggregated

  metrics_before = _aggregate(outputs.initial_model_output)
  metrics_during = _aggregate(outputs.model_output)

  server_state = tff.federated_map(
      server_update_fn,
      (server_state, mean_gradients, metrics_before[client_monitor],
       metrics_before[server_monitor]))
  metrics = collections.OrderedDict(
      before_training=metrics_before, during_training=metrics_during)
  return server_state, metrics
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """The `tff.Computation` to be returned.

  Runs one round of federated GAN training: broadcasts generator and
  discriminator weights to clients, runs the client computation, aggregates
  the discriminator weight deltas (optionally with differential privacy), and
  applies the server update.

  Args:
    server_state: A `gan_training_tf_fns.ServerState` placed at `tff.SERVER`.
    server_gen_inputs: Generator inputs consumed by `server_computation`.
    client_gen_inputs: Client-placed generator inputs.
    client_real_data: Client-placed real-data batches.

  Returns:
    The updated server state produced by `server_computation`.
  """
  # TODO(b/131429028): The federated_zip should be automatic.
  from_server = tff.federated_zip(
      gan_training_tf_fns.FromServer(
          generator_weights=server_state.generator_weights,
          discriminator_weights=server_state.discriminator_weights))
  client_input = tff.federated_broadcast(from_server)
  client_outputs = tff.federated_map(
      client_computation, (client_gen_inputs, client_real_data, client_input))
  if gan.dp_averaging_fn is None:
    # Not using differential privacy: plain weighted federated mean, and the
    # DP averaging state is carried through unchanged.
    new_dp_averaging_state = server_state.dp_averaging_state
    averaged_discriminator_weights_delta = tff.federated_mean(
        client_outputs.discriminator_weights_delta,
        weight=client_outputs.update_weight)
  else:
    # Using differential privacy. Note that the weight argument is set to None
    # here. This is because the DP aggregation code explicitly does not do
    # weighted aggregation. (If weighted aggregation is desired, differential
    # privacy needs to be turned off.)
    new_dp_averaging_state, averaged_discriminator_weights_delta = (
        gan.dp_averaging_fn(server_state.dp_averaging_state,
                            client_outputs.discriminator_weights_delta,
                            weight=None))
  # TODO(b/131085687): Perhaps reconsider the choice to also use
  # ClientOutput to hold the aggregated client output.
  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_discriminator_weights_delta,
      # We don't actually need the aggregated update_weight, but
      # this keeps the types of the non-aggregated and aggregated
      # client_output the same, which is convenient. And I can
      # imagine wanting this.
      update_weight=tff.federated_sum(client_outputs.update_weight),
      counters=tff.federated_sum(client_outputs.counters))
  # TODO(b/131839522): This federated_zip shouldn't be needed.
  aggregated_client_output = tff.federated_zip(aggregated_client_output)
  server_state = tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output,
       new_dp_averaging_state))
  return server_state
def server_init():
  """Builds the initial `ServerState`, placed at `tff.SERVER`."""
  model_at_server, optimizer_state_at_server = tff.federated_eval(
      server_init_tf, tff.SERVER)
  initial_state = ServerState(
      model=model_at_server,
      optimizer_state=optimizer_state_at_server,
      delta_aggregate_state=aggregation_process_init())
  return tff.federated_zip(initial_state)
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_outputs = tff.federated_map(client_update_fn,
                                     (federated_dataset, client_model))
  weight_denom = client_outputs.weights_delta_weight
  round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=weight_denom)
  # `tff.federated_apply` was deprecated and removed in favor of
  # `tff.federated_map`, which also accepts SERVER-placed arguments; this
  # matches the sibling computations in this file.
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, round_model_delta))
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  # Guard the zip as the sibling computations do: an output that is already a
  # single federated value needs no `federated_zip`.
  if isinstance(aggregated_outputs.type_signature, tff.StructType):
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset):
  """Runs one federated round, additionally averaging gradient norms.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the aggregated result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, model_at_clients))
  denom = outputs.weights_delta_weight
  mean_delta = tff.federated_mean(outputs.weights_delta, weight=denom)
  mean_grads_norm = tff.federated_mean(
      outputs.optimizer_output.flat_grads_norm_sum, weight=denom)
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_delta, mean_grads_norm))
  metrics = dummy_model_for_metadata.federated_output_computation(
      outputs.model_output)
  if isinstance(metrics.type_signature, tff.StructType):
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def next_fn(state, deltas, weights):
  """Averages client deltas after clipping each to a global norm bound.

  Args:
    state: The aggregator state; passed through unchanged.
    deltas: Client-placed model updates of type `model_update_type`.
    weights: Client-placed scalar weights for the mean.

  Returns:
    An `OrderedDict` with the unchanged `state`, the weighted mean of the
    clipped deltas as `result`, and clipping `measurements` (max client norm
    and count of clipped clients).
  """

  @tff.tf_computation(model_update_type)
  def clip_by_global_norm(update):
    clipped_update, global_norm = tf.clip_by_global_norm(
        tf.nest.flatten(update), tf.constant(clip_norm))
    # A boolean comparison cast to int32 replaces the original tf.cond:
    # it yields the same 0/1 flag without building two graph branches.
    was_clipped = tf.cast(
        tf.greater(global_norm, tf.constant(clip_norm)), tf.int32)
    clipped_update = tf.nest.pack_sequence_as(update, clipped_update)
    return clipped_update, global_norm, was_clipped

  clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
      clip_by_global_norm, deltas)

  return collections.OrderedDict(
      state=state,
      result=tff.federated_mean(clipped_deltas, weight=weights),
      measurements=tff.federated_zip(
          NormClippedAggregationMetrics(
              max_global_norm=tff.utils.federated_max(client_norms),
              num_clipped=tff.federated_sum(client_was_clipped),
          )))
def run_one_round(server_state, federated_dataset):
  """Runs one round of the TrieHH heavy-hitters computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and an empty server-placed output.
  """
  prefixes_at_clients = tff.federated_broadcast(
      server_state.discovered_prefixes)
  round_num_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      tff.federated_zip(
          [federated_dataset, prefixes_at_clients, round_num_at_clients]))
  total_votes = tff.federated_sum(outputs.client_votes)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, total_votes))
  empty_output = tff.federated_value([], tff.SERVER)
  return server_state, empty_output
def run_one_round(server_state, federated_dataset):
  """Runs one training round, broadcasting the current round number.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the aggregated result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  round_num_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, model_at_clients, round_num_at_clients))
  weight = outputs.client_weight
  mean_delta = tff.federated_mean(outputs.weights_delta, weight=weight)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, mean_delta))
  metrics = dummy_model.federated_output_computation(outputs.model_output)
  if metrics.type_signature.is_struct():
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def comp(temperatures, threshold):
  """Returns the weighted mean of per-client over-threshold counts."""
  threshold_at_clients = tff.federated_broadcast(threshold)
  over_counts = tff.federated_map(
      count_over, tff.federated_zip([temperatures, threshold_at_clients]))
  totals = tff.federated_map(count_total, temperatures)
  return tff.federated_mean(over_counts, totals)
def fed_server_initial_state():
  """Evaluates the initial GAN server state and places it at `tff.SERVER`."""
  raw_state = tff.federated_eval(build_server_initial_state_comp(gan),
                                 tff.SERVER)
  return tff.federated_zip(
      gan_training_tf_fns.ServerState(
          raw_state.generator_weights,
          raw_state.discriminator_weights,
          raw_state.counters,
          aggregation_state=gan.aggregation_process.initialize()))
def server_init_tff():
  """Returns a `reconstruction_utils.ServerState` placed at `tff.SERVER`."""
  init_values = tff.federated_eval(server_init_tf, tff.SERVER)
  aggregator_init = aggregation_process.initialize()
  return tff.federated_zip(
      reconstruction_utils.ServerState(
          model=init_values[0],
          optimizer_state=init_values[1],
          round_num=init_values[2],
          aggregator_state=aggregator_init))
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  In addition to averaging client weight deltas, this measures client
  "drift": each client's distance (via `tf.nn.l2_loss`, i.e. half the
  squared L2 norm) from the weighted mean of all client weights, averaged
  across clients and passed to `server_update_fn`.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_round_num = tff.federated_broadcast(server_state.round_num)
  client_optimizer_weights = tff.federated_broadcast(
      server_state.optimizer_state)
  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, client_model, client_round_num,
                         client_optimizer_weights))
  client_weight = client_outputs.client_weight
  model_delta = tff.federated_mean(client_outputs.weights_delta,
                                   weight=client_weight)
  # Broadcast the weighted mean of the clients' trained weights back to the
  # clients so each one can measure its own distance from this implicit mean.
  implicit_mean_broad = tff.federated_broadcast(
      tff.federated_mean(client_outputs.weights, weight=client_weight))

  @tff.tf_computation(model_weights_type.trainable,
                      model_weights_type.trainable)
  def individual_drift_fn(client_model, implicit_mean):
    # Per-variable difference between this client's weights and the mean.
    drifts = tf.nest.map_structure(lambda a, b: a - b, client_model,
                                   implicit_mean)
    return drifts

  drifts = tff.federated_map(individual_drift_fn,
                             (client_outputs.weights, implicit_mean_broad))

  @tff.tf_computation(model_weights_type.trainable)
  def individual_drift_norm_fn(delta_from_mean):
    # Sum of tf.nn.l2_loss over all variables (0.5 * squared L2 norm of the
    # flattened difference).
    drift_norm = tf.reduce_sum(
        tf.nest.flatten(
            tf.nest.map_structure(lambda a: tf.nn.l2_loss(a),
                                  delta_from_mean)))
    return drift_norm

  drift_norms = tff.federated_map(individual_drift_norm_fn, drifts)
  client_drift = tff.federated_mean(drift_norms, weight=client_weight)
  server_state = tff.federated_map(
      server_update_fn, (server_state, model_delta, client_drift))
  aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def fed_server_initial_state():
  """Evaluates the initial GAN server state (with optional DP averaging)."""
  raw_state = tff.federated_eval(build_server_initial_state_comp(gan),
                                 tff.SERVER)
  if gan.dp_averaging_fn is None:
    dp_averaging_state = raw_state.dp_averaging_state
  else:
    dp_averaging_state = gan.dp_averaging_fn.initialize()
  return tff.federated_zip(
      gan_training_tf_fns.ServerState(
          raw_state.generator_weights,
          raw_state.discriminator_weights,
          raw_state.counters,
          dp_averaging_state=dp_averaging_state))
def run_one_round(server_state, federated_dataset):
  """Runs one training round, optionally masking zeros when averaging.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the aggregated result of
    `tff.learning.Model.federated_output_computation`, merged with the
    averaged additional client outputs.
  """
  # Client work: broadcast model and round number, run local updates.
  model_at_clients = tff.federated_broadcast(server_state.model)
  round_num_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, model_at_clients, round_num_at_clients))

  # Average the model deltas, masking zero entries if configured.
  weight = outputs.client_weight
  if mask_zeros_in_client_updates:
    mean_delta = federated_mean_masked(outputs.weights_delta, weight)
  else:
    mean_delta = tff.federated_mean(outputs.weights_delta, weight)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, mean_delta))

  # Aggregate local metrics plus the extra per-client statistics.
  metrics = placeholder_model.federated_output_computation(
      outputs.model_output)
  extra_metrics = tff.federated_mean(
      outputs.additional_output, weight=weight)

  @tff.tf_computation(metrics.type_signature.member,
                      extra_metrics.type_signature.member)
  def _merge_metrics(base, extra):
    base.update(extra)
    return base

  metrics = tff.federated_map(_merge_metrics, (metrics, extra_metrics))
  if metrics.type_signature.is_struct():
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def run_one_round(server_state, federated_dataset, malicious_dataset,
                  malicious_clients):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
    malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      consisting of malicious datasets.
    malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, malicious_dataset, malicious_clients, client_model))
  weight_denom = client_outputs.weights_delta_weight
  # If the aggregation process' next function takes three arguments it is
  # weighted, otherwise, unweighted. Unfortunately there is no better way
  # to determine this.
  # NOTE(review): this arity check is fragile; if the process exposes an
  # `is_weighted` attribute (as a sibling computation in this file assumes),
  # prefer that — TODO confirm which aggregation-process API is in use here.
  if len(aggregation_process.next.type_signature.parameter) == 3:
    aggregate_output = aggregation_process.next(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)
  else:
    aggregate_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta)
  new_delta_aggregate_state = aggregate_output.state
  round_model_delta = aggregate_output.result
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, round_model_delta, new_delta_aggregate_state))
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  if isinstance(aggregated_outputs.type_signature, tff.StructType):
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Aggregates both the model weight deltas and the client optimizer-state
  deltas, using separate per-client weights for each.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_optimizer_state = tff.federated_broadcast(
      server_state.client_optimizer_state)
  client_round_num = tff.federated_broadcast(server_state.round_num)
  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, client_model,
                         client_optimizer_state, client_round_num))
  # Separate aggregation weights: one for the model delta, one for the
  # optimizer-state delta.
  client_model_weight = client_outputs.client_weight.model_weight
  client_opt_weight = client_outputs.client_weight.optimizer_weight
  model_delta = tff.federated_mean(
      client_outputs.weights_delta, weight=client_model_weight)
  # We convert the optimizer state to a float type so that it can be used
  # with things such as `tff.federated_mean`. This is only necessary because
  # `tf.keras.Optimizer` objects have a state with an integer indicating
  # the number of times it has been applied.
  client_optimizer_state_delta = tff.federated_map(
      _convert_opt_state_to_float, client_outputs.optimizer_state_delta)
  client_optimizer_state_delta = optimizer_aggregator(
      client_optimizer_state_delta, weight=client_opt_weight)
  # We convert the optimizer state back into one with an integer round number.
  client_optimizer_state_delta = tff.federated_map(
      _convert_opt_state_to_int, client_optimizer_state_delta)
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, model_delta, client_optimizer_state_delta))
  aggregated_outputs = placeholder_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def next_fn(state, deltas, weights=None):
  """Clips each client delta by the state's clip norm, then averages.

  Returns the next aggregator state (clip_norm carried through unchanged,
  max_norm set to the largest observed client norm) and the weighted mean
  of the clipped deltas.
  """

  @tff.tf_computation(deltas.type_signature.member, tf.float32)
  def _clip(delta, norm_bound):
    flat_clipped, observed_norm = tf.clip_by_global_norm(
        tf.nest.flatten(delta), norm_bound)
    return tf.nest.pack_sequence_as(delta, flat_clipped), observed_norm

  norm_at_clients = tff.federated_broadcast(state.clip_norm)
  clipped, norms = tff.federated_map(_clip, (deltas, norm_at_clients))
  # clip_norm is a no-op update here but could be adapted from max_norm.
  new_state = tff.federated_zip(
      ClipNormAggregateState(
          clip_norm=state.clip_norm,
          max_norm=tff.utils.federated_max(norms)))
  return new_state, tff.federated_mean(clipped, weight=weights)
def run_one_round(server_state, federated_dataset, malicious_dataset,
                  malicious_clients):
  """Runs one training round in the presence of malicious clients.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
    malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      consisting of malicious datasets.
    malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the aggregated result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, malicious_dataset, malicious_clients,
       model_at_clients))
  denom = outputs.weights_delta_weight
  # Weighted aggregation processes take a weight argument; unweighted ones
  # do not.
  if aggregation_process.is_weighted:
    agg_output = aggregation_process.next(
        server_state.delta_aggregate_state,
        outputs.weights_delta,
        weight=denom)
  else:
    agg_output = aggregation_process.next(
        server_state.delta_aggregate_state, outputs.weights_delta)
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, agg_output.result, agg_output.state))
  metrics = dummy_model_for_metadata.federated_output_computation(
      outputs.model_output)
  if isinstance(metrics.type_signature, tff.StructType):
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def next_fn(state, deltas, weights=None):
  """Clips client deltas by the broadcast `clip_norm` and averages them.

  Returns the next `ClipNormAggregateState` (with `max_norm` updated to the
  maximum observed client norm) and the weighted mean of the clipped deltas.
  """

  @tff.tf_computation(deltas.type_signature.member, tf.float32)
  def clip_by_global_norm(delta, clip_norm):
    # TODO(b/123092620): Replace anonymous_tuple with tf.nest.
    delta = anonymous_tuple.from_container(delta)
    clipped, global_norm = tf.clip_by_global_norm(
        anonymous_tuple.flatten(delta), clip_norm)
    return anonymous_tuple.pack_sequence_as(delta, clipped), global_norm

  client_clip_norm = tff.federated_broadcast(state.clip_norm)
  clipped_deltas, client_norms = tff.federated_map(
      clip_by_global_norm, (deltas, client_clip_norm))
  # clip_norm no-op update here but could be set using max_norm.
  next_state = tff.federated_zip(
      ClipNormAggregateState(
          clip_norm=state.clip_norm,
          max_norm=tff.utils.federated_max(client_norms)))
  return next_state, tff.federated_mean(clipped_deltas, weight=weights)
def run_one_round(server_state, federated_dataset):
  """Runs one training round with a local or joint correction step.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the aggregated result of
    `tff.learning.Model.federated_output_computation`.

  Raises:
    TypeError: If `correction` is neither 'local' nor 'joint'.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, model_at_clients))
  weight = outputs.client_weight
  mean_delta = tff.federated_mean(outputs.weights_delta, weight=weight)
  mean_cor_states = tff.federated_mean(outputs.local_cor_states,
                                       weight=weight)
  if correction == 'joint':
    server_state = tff.federated_map(
        server_update_joint_cor_fn,
        (server_state, mean_delta, mean_cor_states))
  elif correction == 'local':
    server_state = tff.federated_map(server_update_fn,
                                     (server_state, mean_delta))
  else:
    raise TypeError('Correction method must be local or joint.')
  metrics = dummy_model.federated_output_computation(outputs.model_output)
  if metrics.type_signature.is_struct():
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def next_fn(state, value):
  """Modularly clips client values before and after the inner aggregation."""
  lower, upper = self._get_clip_range()
  # Wrap each client value into the clip range before aggregating.
  pre_clipped = tff.federated_map(
      modular_clip_by_value_tff,
      (value, tff.federated_broadcast(lower),
       tff.federated_broadcast(upper)))
  inner_state, inner_result, inner_measurements = inner_agg_next(
      state, pre_clipped)
  # Wrap the aggregate into the same range again (summands not considered).
  result = tff.federated_map(modular_clip_by_value_tff,
                             (inner_result, lower, upper))
  return tff.templates.MeasuredProcessOutput(
      state=inner_state,
      result=result,
      measurements=tff.federated_zip(
          collections.OrderedDict(agg_process=inner_measurements)))
def foo(x):
  """Returns `x` zipped into a single federated value."""
  zipped = tff.federated_zip(x)
  return zipped
def mean_over_threshold(temperatures, threshold):
  """Returns the weighted mean of per-client over-threshold counts."""
  threshold_at_clients = tff.federated_broadcast(threshold)
  over_counts = tff.federated_map(
      count_over, tff.federated_zip([temperatures, threshold_at_clients]))
  totals = tff.federated_map(count_total, temperatures)
  return tff.federated_mean(over_counts, totals)
def run_one_round(server_state, federated_dataset, ids):
  """Orchestration logic for one round of computation.

  Scatters each client's loss and weight into fixed-size server tensors
  (indexed by client id), zeroes the weights of high-loss clients via
  `zero_small_loss_clients`, and aggregates deltas with the selected weights.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
    ids: Client-placed identifiers used to index into the per-client tensors.
      NOTE(review): presumably each id is an index tensor compatible with
      `tf.tensor_scatter_nd_update` on shape [total_clients, 1] — TODO
      confirm against `client_update_fn`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_round_num = tff.federated_broadcast(server_state.round_num)
  client_outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, client_model, client_round_num, ids))
  # NOTE(review): client_weight is only used by the commented-out
  # alternatives below; client_id is read inside the reducers via
  # `t.client_id`.
  client_weight = client_outputs.client_weight
  client_id = client_outputs.client_id

  #LOSS SELECTION:
  # losses_at_server = tff.federated_collect(client_outputs.model_output)
  # weights_at_server = tff.federated_collect(client_weight)

  @computations.tf_computation
  def zeros_fn():
    # Accumulator seed: one float slot per client.
    return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

  zero = zeros_fn()
  at_server_type = tff.TensorType(shape=[total_clients, 1], dtype=tf.float32)
  # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
  client_output_type = client_update_fn.type_signature.result

  @computations.tf_computation(at_server_type, client_output_type)
  def accumulate_weight(u, t):
    # Scatter this client's weight into its slot of the accumulator.
    value = t.client_weight
    index = t.client_id
    new_u = tf.tensor_scatter_nd_update(u, index, value)
    return new_u

  @computations.tf_computation(at_server_type, client_output_type)
  def accumulate_loss(u, t):
    # Scatter this client's summed loss into its slot of the accumulator.
    value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                       shape=[1, 1])
    index = t.client_id
    new_u = tf.tensor_scatter_nd_update(u, index, value)
    return new_u

  # NOTE(review): `tff.federated_reduce` is a deprecated API (superseded by
  # `tff.federated_aggregate`) — TODO confirm the TFF version pinned here.
  # output_at_server= tff.federated_collect(client_outputs)
  weights_at_server = tff.federated_reduce(client_outputs, zero,
                                           accumulate_weight)
  losses_at_server = tff.federated_reduce(client_outputs, zero,
                                          accumulate_loss)
  #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

  # Keep only the `effective_num_clients` clients with the smallest losses.
  selected_clients_weights = tff.federated_map(
      zero_small_loss_clients,
      (losses_at_server, weights_at_server,
       server_state.effective_num_clients))
  # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)
  selected_clients_weights_broadcast = tff.federated_broadcast(
      selected_clients_weights)
  selected_clients_weights_at_client = tff.federated_map(
      select_weight_fn, (selected_clients_weights_broadcast, ids))
  aggregation_output = aggregation_process.next(
      server_state.delta_aggregate_state, client_outputs.weights_delta,
      selected_clients_weights_at_client)
  # model_delta = tff.federated_mean(
  #     client_outputs.weights_delta, weight=client_weight)
  server_state = tff.federated_map(
      server_update_fn, (server_state, aggregation_output.result))
  aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs