def run_one_round(server_state, federated_dataset):
  """Performs a single federated round: broadcast, client work, aggregate.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of the updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  round_num_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, model_at_clients, round_num_at_clients))
  # Average the per-client deltas, weighted by each client's reported weight.
  weight = outputs.client_weight
  model_delta = tff.federated_mean(outputs.weights_delta, weight=weight)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, model_delta))
  metrics = dummy_model.federated_output_computation(outputs.model_output)
  # Structured metrics must be zipped into a single server-placed value.
  if metrics.type_signature.is_struct():
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def run_one_round(server_state, federated_dataset):
  """Runs one round of the TrieHH heavy-hitters discovery protocol.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    An updated `ServerState` together with an empty server-placed output.
  """
  prefixes_at_clients = tff.federated_broadcast(
      server_state.discovered_prefixes)
  round_num_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, prefixes_at_clients, round_num_at_clients))
  # Votes and weights are summed (not averaged) across clients.
  total_votes = tff.federated_sum(outputs.client_votes)
  total_weight = tff.federated_sum(outputs.client_weight)
  server_state = tff.federated_map(
      server_update_fn, (server_state, total_votes, total_weight))
  empty_output = tff.federated_value([], tff.SERVER)
  return server_state, empty_output
def federated_train(model, learning_rate, data, classes):
  """Trains locally on every client and returns the mean of the local models."""
  per_client_args = [
      tff.federated_broadcast(model),
      tff.federated_broadcast(learning_rate),
      data,
      tff.federated_broadcast(classes),
  ]
  local_models = tff.federated_map(local_train, per_client_args)
  return tff.federated_mean(local_models)
def run_one_round(server_state, client_states, federated_dataset,
                  federated_dataset_single):
  """Runs one round of SCAFFOLD-style training with control-variate corrections.

  Args:
    server_state: A server state holding `w`, `meta_w` and `round_num`.
    client_states: A federated client state with placement `tff.CLIENTS`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`
      used for the main client update.
    federated_dataset_single: A federated `tf.Dataset` used to estimate the
      control variates.

  Returns:
    A tuple of the updated server state and the updated client states.
  """
  from_server = FromServer(
      w=server_state.w,
      meta_w=server_state.meta_w,
      round_num=server_state.round_num)
  from_server = tff.federated_broadcast(from_server)
  # BUG FIX: the original body referenced an undefined name `client_state`;
  # the parameter is `client_states`.
  control_output = tff.federated_map(
      control_computation,
      (federated_dataset_single, client_states, from_server))
  # Server control variate: weighted mean of the client control deltas,
  # broadcast back to the clients.
  c = tff.federated_broadcast(
      tff.federated_mean(
          control_output.w_delta, weight=control_output.client_weight))

  @tff.tf_computation(client_output_type.w_delta, client_output_type.w_delta)
  def compute_control_input(c, c_i):
    # Per-client gradient correction: server control minus client control.
    correction = tf.nest.map_structure(lambda a, b: a - b, c, c_i)
    # If we are using SCAFFOLD then use the correction, otherwise
    # we just let the correction be zero.
    return tf.cond(
        tf.constant(control, dtype=tf.bool), lambda: correction,
        lambda: tf.nest.map_structure(tf.zeros_like, correction))

  # The collection of gradient corrections, one per client.
  corrections = tff.federated_map(compute_control_input,
                                  (c, control_output.cs))
  client_outputs = tff.federated_map(
      client_computation,
      (federated_dataset, from_server, client_states, corrections))
  w_delta = tff.federated_mean(
      client_outputs.w_delta, weight=client_outputs.client_weight)
  server_state = tff.federated_map(server_computation,
                                   (server_state, w_delta))
  return (server_state, client_outputs.client_state)
def federated_train(model, learning_rate, data):
  """Runs local training on each client and returns the per-client results."""
  model_at_clients = tff.federated_broadcast(model)
  lr_at_clients = tff.federated_broadcast(learning_rate)
  return tff.federated_map(local_train,
                           [model_at_clients, lr_at_clients, data])
def federated_train(model, lr, data):
  # Returns the averaged model after one round of local training.
  locally_trained = tff.federated_map(
      local_train,
      [tff.federated_broadcast(model),
       tff.federated_broadcast(lr), data])
  return tff.federated_mean(locally_trained)
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  In addition to federated averaging of the model deltas, this round
  measures "client drift": how far each client's trained weights deviate
  from the weighted mean of all clients' weights.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_round_num = tff.federated_broadcast(server_state.round_num)
  client_optimizer_weights = tff.federated_broadcast(
      server_state.optimizer_state)
  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, client_model, client_round_num,
                         client_optimizer_weights))
  client_weight = client_outputs.client_weight
  model_delta = tff.federated_mean(client_outputs.weights_delta,
                                   weight=client_weight)
  # Weighted mean of the post-training client weights ("implicit mean"),
  # broadcast back so each client's drift from it can be computed.
  implicit_mean_broad = tff.federated_broadcast(
      tff.federated_mean(client_outputs.weights, weight=client_weight))

  @tff.tf_computation(model_weights_type.trainable,
                      model_weights_type.trainable)
  def individual_drift_fn(client_model, implicit_mean):
    # Element-wise difference between a client's weights and the mean.
    drifts = tf.nest.map_structure(lambda a, b: a - b, client_model,
                                   implicit_mean)
    return drifts

  drifts = tff.federated_map(individual_drift_fn,
                             (client_outputs.weights, implicit_mean_broad))

  @tff.tf_computation(model_weights_type.trainable)
  def individual_drift_norm_fn(delta_from_mean):
    # tf.nn.l2_loss(a) is 0.5 * sum(a**2), so this is half the squared
    # L2 norm summed over all weight tensors.
    drift_norm = tf.reduce_sum(
        tf.nest.flatten(
            tf.nest.map_structure(lambda a: tf.nn.l2_loss(a),
                                  delta_from_mean)))
    return drift_norm

  drift_norms = tff.federated_map(individual_drift_norm_fn, drifts)
  # Average drift across clients, with the same weights as the model mean.
  client_drift = tff.federated_mean(drift_norms, weight=client_weight)
  server_state = tff.federated_map(
      server_update_fn, (server_state, model_delta, client_drift))
  aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Note that in addition to updating the server weights according to the
  client model weight deltas, we extract metrics (governed by the `monitor`
  attribute of the `client_lr_callback` and `server_lr_callback` attributes
  of the `server_state`) and use these to update the client learning rate
  callbacks.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation` before and during
    local client training.
  """
  client_model = tff.federated_broadcast(server_state.model)
  # The client learning rate is server-controlled via the LR callback.
  client_lr = tff.federated_broadcast(
      server_state.client_lr_callback.learning_rate)
  # Optionally preprocess the client datasets before training.
  if dataset_preprocess_comp is not None:
    federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                          federated_dataset)
  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, client_model, client_lr))
  client_weight = client_outputs.client_weight
  aggregated_gradients = tff.federated_mean(
      client_outputs.accumulated_gradients, weight=client_weight)
  # Metrics evaluated BEFORE local training; these drive the LR callbacks.
  initial_aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.initial_model_output)
  if isinstance(initial_aggregated_outputs.type_signature, tff.StructType):
    initial_aggregated_outputs = tff.federated_zip(
        initial_aggregated_outputs)
  # Metrics accumulated DURING local training.
  aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.model_output)
  if isinstance(aggregated_outputs.type_signature, tff.StructType):
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  # The monitored quantities (e.g. a loss) that the LR callbacks react to.
  client_monitor_value = initial_aggregated_outputs[client_monitor]
  server_monitor_value = initial_aggregated_outputs[server_monitor]
  server_state = tff.federated_map(
      server_update_fn, (server_state, aggregated_gradients,
                         client_monitor_value, server_monitor_value))
  result = collections.OrderedDict(
      before_training=initial_aggregated_outputs,
      during_training=aggregated_outputs)
  return server_state, result
def run_one_round(server_state, federated_dataset): """Orchestration logic for one round of computation. Args: server_state: A `ServerState`. federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`. Returns: A tuple of updated `ServerState` and the result of `tff.learning.Model.federated_output_computation`. """ # Run computation on the clients. client_model = tff.federated_broadcast(server_state.model) client_round_num = tff.federated_broadcast(server_state.round_num) client_outputs = tff.federated_map( client_update_fn, (federated_dataset, client_model, client_round_num)) # Aggregate model deltas. client_weight = client_outputs.client_weight if mask_zeros_in_client_updates: model_delta = federated_mean_masked(client_outputs.weights_delta, client_weight) else: model_delta = tff.federated_mean(client_outputs.weights_delta, client_weight) server_state = tff.federated_map(server_update_fn, (server_state, model_delta)) # Aggregate model outputs that contain local metrics and various statistics. aggregated_outputs = placeholder_model.federated_output_computation( client_outputs.model_output) additional_outputs = tff.federated_mean( client_outputs.additional_output, weight=client_weight) @tff.tf_computation(aggregated_outputs.type_signature.member, additional_outputs.type_signature.member) def _update_aggregated_outputs(aggregated_outputs, additional_outputs): aggregated_outputs.update(additional_outputs) return aggregated_outputs aggregated_outputs = tff.federated_map( _update_aggregated_outputs, (aggregated_outputs, additional_outputs)) if aggregated_outputs.type_signature.is_struct(): aggregated_outputs = tff.federated_zip(aggregated_outputs) return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset):
  """Runs one training round: server message, client updates, aggregation.

  Args:
    server_state: A `ServerState` containing the state of the training
      process up to the current round.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS` containing data to train the current round on.

  Returns:
    A tuple of updated `ServerState` and `tf.Tensor` of average loss.
  """
  message = tff.federated_map(server_message_fn, server_state)
  message_at_clients = tff.federated_broadcast(message)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, message_at_clients))
  # Weighted average of the client deltas.
  mean_delta = tff.federated_mean(outputs.weights_delta,
                                  weight=outputs.client_weight)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, mean_delta))
  metrics = metrics_aggregation_computation(outputs.model_output)
  return server_state, metrics
def run_one_round(server_state, federated_dataset):
  """One training round with differentially private (unweighted) averaging.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and `tf.Tensor` of average loss.
  """
  message = tff.federated_map(server_message_fn, server_state)
  message_at_clients = tff.federated_broadcast(message)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, message_at_clients))
  # Model deltas are equally weighted under DP.
  mean_delta = tff.federated_mean(outputs.weights_delta)
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, mean_delta))
  avg_loss = tff.federated_mean(outputs.model_output)
  return server_state, avg_loss
def run_one_round(server_state, federated_dataset):
  """One round of training, additionally averaging the gradient-norm sums.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, model_at_clients))
  denom = outputs.weights_delta_weight
  mean_delta = tff.federated_mean(outputs.weights_delta, weight=denom)
  # Average of each client's accumulated flat-gradient norm, same weighting.
  mean_grads_norm = tff.federated_mean(
      outputs.optimizer_output.flat_grads_norm_sum, weight=denom)
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_delta, mean_grads_norm))
  metrics = dummy_model_for_metadata.federated_output_computation(
      outputs.model_output)
  if isinstance(metrics.type_signature, tff.StructType):
    metrics = tff.federated_zip(metrics)
  return server_state, metrics
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_outputs = tff.federated_map(client_update_fn,
                                     (federated_dataset, client_model))
  weight_denom = client_outputs.weights_delta_weight
  round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                         weight=weight_denom)
  # `tff.federated_apply` was deprecated and later removed; `tff.federated_map`
  # applies the same function to a SERVER-placed value and matches the
  # sibling round implementations in this file.
  server_state = tff.federated_map(server_update_fn,
                                   (server_state, round_model_delta))
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset, client_states):
  """Orchestration logic for one round of stateful federated averaging.

  Args:
    server_state: A `stateful_fedavg_tf.ServerState`.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS`.
    client_states: A federated `stateful_fedavg_tf.ClientState`.

  Returns:
    A tuple of updated `ServerState`, `tf.Tensor` of average loss, and the
    updated client states.
  """
  message = tff.federated_map(server_message_fn, server_state)
  message_at_clients = tff.federated_broadcast(message)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, client_states, message_at_clients))
  denom = outputs.client_weight
  mean_delta = tff.federated_mean(outputs.weights_delta, weight=denom)
  # Total number of local iterations performed across all clients.
  iters_total = tff.federated_sum(outputs.client_state.iters_count)
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_delta, iters_total))
  avg_loss = tff.federated_mean(outputs.model_output, weight=denom)
  return server_state, avg_loss, outputs.client_state
def validate(weights, datasets, client_states):
  """Evaluates `weights` on every client and aggregates the metrics."""
  weights_at_clients = tff.federated_broadcast(weights)
  client_results = tff.federated_map(
      validate_client_tf, (datasets, client_states, weights_at_clients))
  return model.federated_output_computation(client_results.metrics)
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  In addition to averaging model deltas, this round also aggregates the
  clients' optimizer-state deltas, using a separate weight for each.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_optimizer_state = tff.federated_broadcast(
      server_state.client_optimizer_state)
  client_round_num = tff.federated_broadcast(server_state.round_num)
  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, client_model,
                         client_optimizer_state, client_round_num))
  # Separate weights for averaging the model vs. the optimizer state.
  client_model_weight = client_outputs.client_weight.model_weight
  client_opt_weight = client_outputs.client_weight.optimizer_weight
  model_delta = tff.federated_mean(
      client_outputs.weights_delta, weight=client_model_weight)
  # We convert the optimizer state to a float type so that it can be used
  # with things such as `tff.federated_mean`. This is only necessary because
  # `tf.keras.Optimizer` objects have a state with an integer indicating
  # the number of times it has been applied.
  client_optimizer_state_delta = tff.federated_map(
      _convert_opt_state_to_float, client_outputs.optimizer_state_delta)
  client_optimizer_state_delta = optimizer_aggregator(
      client_optimizer_state_delta, weight=client_opt_weight)
  # We convert the optimizer state back into one with an integer round number.
  client_optimizer_state_delta = tff.federated_map(
      _convert_opt_state_to_int, client_optimizer_state_delta)
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, model_delta, client_optimizer_state_delta))
  aggregated_outputs = placeholder_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def train_one_round(model, federated_data):
  """Trains on each client and aggregates the local models into one."""
  per_client_args = collections.OrderedDict([
      ('model', tff.federated_broadcast(model)),
      ('batches', federated_data),
  ])
  local_models = tff.federated_map(train_on_one_client, per_client_args)
  # Combine local results via a custom accumulate/merge/report aggregator.
  return tff.federated_aggregate(local_models, make_zero_model_and_count(),
                                 accumulate, merge, report)
def comp(temperatures, threshold):
  """Returns the mean of per-client `count_over`, weighted by `count_total`."""
  threshold_at_clients = tff.federated_broadcast(threshold)
  over_counts = tff.federated_map(
      count_over, tff.federated_zip([temperatures, threshold_at_clients]))
  totals = tff.federated_map(count_total, temperatures)
  # Second positional argument to federated_mean is the weight.
  return tff.federated_mean(over_counts, totals)
def robust_aggregation_fn(value, weight):
  """Computes a robust mean via iterative re-weighting.

  Starts from a plain weighted mean; on each additional communication pass
  the clients recompute weights (from the ORIGINAL weights and the current
  aggregate) before the mean is taken again.
  """
  current = tff.federated_mean(value, weight=weight)
  for _ in range(num_communication_passes - 1):
    current_at_clients = tff.federated_broadcast(current)
    # Note: always derived from the original `weight`, not the previous pass.
    new_weight = tff.federated_map(update_weight_fn,
                                   (weight, current_at_clients, value))
    current = tff.federated_mean(value, weight=new_weight)
  return current
def _robust_aggregation_fn(state, value, weight):
  """Measured-process wrapper around the multi-pass robust weighted mean."""
  current = tff.federated_mean(value, weight=weight)
  for _ in range(num_communication_passes - 1):
    current_at_clients = tff.federated_broadcast(current)
    # Re-weighting always starts from the original `weight`.
    reweighted = tff.federated_map(update_weight_fn,
                                   (weight, current_at_clients, value))
    current = tff.federated_mean(value, weight=reweighted)
  # No metrics are reported by this process.
  empty_metrics = tff.federated_value((), tff.SERVER)
  return tff.templates.MeasuredProcessOutput(state, current, empty_metrics)
def evaluate(weights, datasets, client_states):
  """Evaluates `weights` on all clients.

  Returns the summed confusion matrix, aggregated metrics, and the
  per-client metrics collected at the server.
  """
  weights_at_clients = tff.federated_broadcast(weights)
  client_results = tff.federated_map(
      evaluate_client_tf, (datasets, client_states, weights_at_clients))
  total_confusion = tff.federated_sum(client_results.confusion_matrix)
  aggregated = model.federated_output_computation(client_results.metrics)
  # NOTE(review): `tff.federated_collect` is deprecated/removed in newer
  # TFF releases — confirm the pinned TFF version still provides it.
  per_client_metrics = tff.federated_collect(client_results.metrics)
  return total_confusion, aggregated, per_client_metrics
def run_one_round(server_state, federated_dataset):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and aggregated metrics.
  """
  model_at_clients = tff.federated_broadcast(server_state.model)
  round_at_clients = tff.federated_broadcast(server_state.round_num)
  outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, model_at_clients, round_at_clients))
  # A three-parameter `next` signature indicates a weighted aggregation
  # process; there is unfortunately no better way to detect this.
  is_weighted = len(aggregation_process.next.type_signature.parameter) == 3
  if is_weighted:
    aggregation_output = aggregation_process.next(
        server_state.aggregator_state,
        outputs.weights_delta,
        weight=outputs.client_weight)
  else:
    aggregation_output = aggregation_process.next(
        server_state.aggregator_state, outputs.weights_delta)
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, aggregation_output.result, aggregation_output.state))
  metrics = federated_output_computation(outputs.model_output)
  # The `measurements` portion of the aggregation output is intentionally
  # dropped, as it is not needed for these experiments.
  return server_state, metrics
def next_fn(server_weights, federated_dataset):
  """Runs one round of federated averaging.

  Fixes the original draft: federated values placed at `tff.CLIENTS` cannot
  be passed to plain Python/TF functions or averaged with `np.mean` — they
  must be transformed with `tff.federated_map` and averaged with
  `tff.federated_mean`.

  Args:
    server_weights: Server-placed model weights.
    federated_dataset: A federated dataset with placement `tff.CLIENTS`.

  Returns:
    The updated server-placed weights.
  """
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)
  # Each client computes their updated weights on its local data.
  client_weights = tff.federated_map(
      client_update, (federated_dataset, server_weights_at_client))
  # The server averages these updates across clients.
  mean_client_weights = tff.federated_mean(client_weights)
  # The server updates its model with the averaged weights.
  server_weights = tff.federated_map(server_update, mean_client_weights)
  return server_weights
def encoded_broadcast_fn(state, value):
  """Broadcast function, to be wrapped as federated_computation.

  Encodes `value` on the server, broadcasts the encoded representation to
  the clients, and decodes it there.

  Args:
    state: Server-placed encoder state.
    value: Server-placed value to broadcast.

  Returns:
    A tuple of the new encoder state and the decoded client-placed value.
  """
  state_type = state.type_signature.member
  value_type = value.type_signature.member
  encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
      state_type, value_type, encoders)
  # `tff.federated_apply` was deprecated and later removed;
  # `tff.federated_map` performs the same server-placed application.
  new_state, encoded_value = tff.federated_map(encode_fn, (state, value))
  client_encoded_value = tff.federated_broadcast(encoded_value)
  client_value = tff.federated_map(decode_fn, client_encoded_value)
  return new_state, client_value
def run_one_round(server_state, federated_dataset, malicious_dataset,
                  malicious_clients):
  """Orchestration logic for one round of computation.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
    malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      consisting of malicious datasets.
    malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  # Each client receives both a benign and a malicious dataset, plus a flag
  # indicating whether it should behave maliciously this round.
  client_outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, malicious_dataset, malicious_clients, client_model))
  weight_denom = client_outputs.weights_delta_weight
  # If the aggregation process' next function takes three arguments it is
  # weighted, otherwise, unweighted. Unfortunately there is no better way
  # to determine this.
  if len(aggregation_process.next.type_signature.parameter) == 3:
    aggregate_output = aggregation_process.next(
        server_state.delta_aggregate_state,
        client_outputs.weights_delta,
        weight=weight_denom)
  else:
    aggregate_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta)
  new_delta_aggregate_state = aggregate_output.state
  round_model_delta = aggregate_output.result
  server_state = tff.federated_map(
      server_update_fn,
      (server_state, round_model_delta, new_delta_aggregate_state))
  aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
      client_outputs.model_output)
  if isinstance(aggregated_outputs.type_signature, tff.StructType):
    aggregated_outputs = tff.federated_zip(aggregated_outputs)
  return server_state, aggregated_outputs
def next_tff(server_state, datasets, client_states):
  """One federated round: broadcast a message, update clients, update server."""
  server_message = tff.federated_map(state_to_message_tf, server_state)
  message_at_clients = tff.federated_broadcast(server_message)
  client_results = tff.federated_map(
      update_client_tf, (datasets, client_states, message_at_clients))
  mean_delta = tff.federated_mean(client_results.weights_delta,
                                  weight=client_results.client_weight)
  round_metrics = model.federated_output_computation(client_results.metrics)
  new_server_state = tff.federated_map(update_server_tf,
                                       (server_state, mean_delta))
  return new_server_state, round_metrics, client_results.client_state
def run_gradient_computation_round(server_state, federated_dataset):
  """Computes per-client sampling probabilities from weighted update norms.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.data.Dataset` with placement
      `tff.CLIENTS`.

  Returns:
    A tuple of the clients' initial probabilities and the `ClientOutput`.
  """
  message = tff.federated_map(server_message_fn, server_state)
  message_at_clients = tff.federated_broadcast(message)
  outputs = tff.federated_map(client_update_fn,
                              (federated_dataset, message_at_clients))
  # Normalize each client's weighted update norm by the global sum.
  norm_sum = tff.federated_sum(outputs.update_norm_weighted)
  norm_sum_at_clients = tff.federated_broadcast(norm_sum)
  prob_init = scale_on_clients(outputs.update_norm_weighted,
                               norm_sum_at_clients)
  return prob_init, outputs
def next_fn(state, value): clip_range_lower, clip_range_upper = self._get_clip_range() # Modular clip values before aggregation. clipped_value = tff.federated_map( modular_clip_by_value_tff, (value, tff.federated_broadcast(clip_range_lower), tff.federated_broadcast(clip_range_upper))) (agg_output_state, agg_output_result, agg_output_measurements) = inner_agg_next(state, clipped_value) # Clip the aggregate to the same range again (not considering summands). clipped_agg_output_result = tff.federated_map( modular_clip_by_value_tff, (agg_output_result, clip_range_lower, clip_range_upper)) measurements = collections.OrderedDict( agg_process=agg_output_measurements) return tff.templates.MeasuredProcessOutput( state=agg_output_state, result=clipped_agg_output_result, measurements=tff.federated_zip(measurements))
def next_fn(server_weights, federated_dataset):
  """One round: broadcast, local client training, averaging, server update."""
  # Broadcast the server model to every client.
  weights_at_clients = tff.federated_broadcast(server_weights)
  # Run the local update on each client, producing new weights and losses.
  client_weights, clients_loss = tff.federated_map(
      client_update_fn, (federated_dataset, weights_at_clients))
  # Average the client model weights on the server.
  averaged_weights = tff.federated_mean(client_weights)
  # Apply the averaged weights to the server model.
  server_weights = tff.federated_map(server_update_fn, averaged_weights)
  return server_weights, client_weights, clients_loss
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """The `tff.Computation` to be returned.

  Runs one round of federated GAN training: broadcasts the generator and
  discriminator weights, trains on the clients, aggregates the
  discriminator deltas (optionally with differential privacy), and updates
  the server state.
  """
  # TODO(b/131429028): The federated_zip should be automatic.
  from_server = tff.federated_zip(
      gan_training_tf_fns.FromServer(
          generator_weights=server_state.generator_weights,
          discriminator_weights=server_state.discriminator_weights))
  client_input = tff.federated_broadcast(from_server)
  client_outputs = tff.federated_map(
      client_computation,
      (client_gen_inputs, client_real_data, client_input))
  if gan.dp_averaging_fn is None:
    # Not using differential privacy.
    new_dp_averaging_state = server_state.dp_averaging_state
    averaged_discriminator_weights_delta = tff.federated_mean(
        client_outputs.discriminator_weights_delta,
        weight=client_outputs.update_weight)
  else:
    # Using differential privacy. Note that the weight argument is set to None
    # here. This is because the DP aggregation code explicitly does not do
    # weighted aggregation. (If weighted aggregation is desired, differential
    # privacy needs to be turned off.)
    new_dp_averaging_state, averaged_discriminator_weights_delta = (
        gan.dp_averaging_fn(server_state.dp_averaging_state,
                            client_outputs.discriminator_weights_delta,
                            weight=None))
  # TODO(b/131085687): Perhaps reconsider the choice to also use
  # ClientOutput to hold the aggregated client output.
  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_discriminator_weights_delta,
      # We don't actually need the aggregated update_weight, but
      # this keeps the types of the non-aggregated and aggregated
      # client_output the same, which is convenient. And I can
      # imagine wanting this.
      update_weight=tff.federated_sum(client_outputs.update_weight),
      counters=tff.federated_sum(client_outputs.counters))
  # TODO(b/131839522): This federated_zip shouldn't be needed.
  aggregated_client_output = tff.federated_zip(aggregated_client_output)
  server_state = tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output,
       new_dp_averaging_state))
  return server_state