def run_one_round(server_state, federated_dataset, client_states):
    """Orchestration logic for one round of stateful computation.

    Args:
      server_state: A `stateful_fedavg_tf.ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
      client_states: A federated `stateful_fedavg_tf.ClientState`.

    Returns:
      A 3-tuple of:
        * the updated `ServerState`,
        * the mean of the clients' `model_output` (the round loss metric),
          weighted by each client's `client_weight`,
        * the updated federated `ClientState` returned by `client_update_fn`.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_states, server_message_at_client))

    weight_denom = client_outputs.client_weight
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)
    # Total local iterations across all clients, consumed by the server update.
    total_iters_count = tff.federated_sum(
        client_outputs.client_state.iters_count)
    server_state = tff.federated_map(
        server_update_fn, (server_state, round_model_delta, total_iters_count))
    round_loss_metric = tff.federated_mean(
        client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.client_state
def run_one_round(server_state, federated_dataset):
    """Run a single round of federated training.

    The server model is broadcast to every client and each client trains
    locally. The server then applies the weighted-mean model delta together
    with the weighted-mean of the clients' `flat_grads_norm_sum`.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    model_at_clients = tff.federated_broadcast(server_state.model)
    per_client = tff.federated_map(client_update_fn,
                                   (federated_dataset, model_at_clients))

    # Both server-side statistics share the same per-client weighting.
    delta_weight = per_client.weights_delta_weight

    def weighted_mean(values):
        return tff.federated_mean(values, weight=delta_weight)

    mean_delta = weighted_mean(per_client.weights_delta)
    mean_grads_norm = weighted_mean(
        per_client.optimizer_output.flat_grads_norm_sum)

    new_state = tff.federated_map(
        server_update_fn, (server_state, mean_delta, mean_grads_norm))

    metrics = dummy_model_for_metadata.federated_output_computation(
        per_client.model_output)
    # A Python struct of federated values is zipped into one federated value.
    if isinstance(metrics.type_signature, tff.StructType):
        metrics = tff.federated_zip(metrics)
    return new_state, metrics
def run_one_round(server_state, client_states, federated_dataset,
                  federated_dataset_single):
    """One round of control-variate (SCAFFOLD-style) federated training.

    NOTE(review): `FromServer`, `control`, `control_computation`,
    `client_computation`, `server_computation` and `client_output_type` are
    resolved from the enclosing scope.

    Args:
      server_state: Server state exposing `w`, `meta_w` and `round_num`.
      client_states: Federated per-client state at `tff.CLIENTS`.
      federated_dataset: Federated client datasets for the main local update.
      federated_dataset_single: Federated client datasets passed to
        `control_computation`.

    Returns:
      A tuple `(server_state, client_states)` with the updated server state
      and the per-client states produced by `client_computation`.
    """
    from_server = FromServer(
        w=server_state.w,
        meta_w=server_state.meta_w,
        round_num=server_state.round_num)
    from_server = tff.federated_broadcast(from_server)

    # The per-client state comes from the `client_states` parameter.
    control_output = tff.federated_map(
        control_computation,
        (federated_dataset_single, client_states, from_server))

    # Weighted mean of the clients' `w_delta` from the control pass,
    # broadcast back so each client can compare it with its own `cs`.
    c = tff.federated_broadcast(
        tff.federated_mean(control_output.w_delta,
                           weight=control_output.client_weight))

    @tff.tf_computation(client_output_type.w_delta,
                        client_output_type.w_delta)
    def compute_control_input(c, c_i):
        correction = tf.nest.map_structure(lambda a, b: a - b, c, c_i)
        # If we are using SCAFFOLD then use the correction; otherwise
        # the correction is all zeros.
        return tf.cond(
            tf.constant(control, dtype=tf.bool),
            lambda: correction,
            lambda: tf.nest.map_structure(tf.zeros_like, correction))

    # The collection of per-client gradient corrections.
    corrections = tff.federated_map(compute_control_input,
                                    (c, control_output.cs))

    client_outputs = tff.federated_map(
        client_computation,
        (federated_dataset, from_server, client_states, corrections))

    w_delta = tff.federated_mean(client_outputs.w_delta,
                                 weight=client_outputs.client_weight)
    server_state = tff.federated_map(server_computation,
                                     (server_state, w_delta))
    return (server_state, client_outputs.client_state)
def run_one_round(server_state, federated_dataset):
    """Run one round of equally-weighted (DP-friendly) federated averaging.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the unweighted mean of the
      clients' `model_output` (presumably the average loss).
    """
    message = tff.federated_map(server_message_fn, server_state)
    outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, tff.federated_broadcast(message)))

    # Model deltas are equally weighted in DP, so no `weight` is supplied.
    mean_delta = tff.federated_mean(outputs.weights_delta)
    updated_state = tff.federated_map(server_update_fn,
                                      (server_state, mean_delta))

    mean_loss = tff.federated_mean(outputs.model_output)
    return updated_state, mean_loss
def robust_aggregation_fn(value, weight):
    """Iteratively reweighted federated mean of `value`.

    The estimate starts as the `weight`-weighted mean. Each of the
    `num_communication_passes - 1` extra passes does three things:
      * broadcasts the current estimate;
      * asks `update_weight_fn` for new per-client weights, computed from the
        ORIGINAL `weight`, the estimate and the client's `value`;
      * re-averages `value` with those weights.

    Args:
      value: Federated value at `tff.CLIENTS`.
      weight: Federated per-client weights at `tff.CLIENTS`.

    Returns:
      The final estimate, placed at `tff.SERVER`.
    """
    estimate = tff.federated_mean(value, weight=weight)
    extra_passes = num_communication_passes - 1
    for _ in range(extra_passes):
        estimate_at_clients = tff.federated_broadcast(estimate)
        new_weight = tff.federated_map(
            update_weight_fn, (weight, estimate_at_clients, value))
        estimate = tff.federated_mean(value, weight=new_weight)
    return estimate
def _robust_aggregation_fn(state, value, weight):
    """Iteratively reweighted federated mean, as a measured-process step.

    The estimate starts as the `weight`-weighted mean of `value`. Each of the
    `num_communication_passes - 1` extra passes does three things:
      * broadcasts the current estimate;
      * lets `update_weight_fn` derive new per-client weights from the
        original `weight`, the estimate and the client's `value`;
      * re-averages with those weights.

    Args:
      state: Process state; returned unchanged.
      value: Federated value at `tff.CLIENTS` to aggregate.
      weight: Federated per-client weights at `tff.CLIENTS`.

    Returns:
      A `tff.templates.MeasuredProcessOutput` holding the unchanged `state`,
      the final estimate as `result`, and empty `()` measurements placed at
      `tff.SERVER`.
    """
    estimate = tff.federated_mean(value, weight=weight)
    for _ in range(num_communication_passes - 1):
        reweighted = tff.federated_map(
            update_weight_fn,
            (weight, tff.federated_broadcast(estimate), value))
        estimate = tff.federated_mean(value, weight=reweighted)

    empty_measurements = tff.federated_value((), tff.SERVER)
    return tff.templates.MeasuredProcessOutput(
        state=state, result=estimate, measurements=empty_measurements)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Besides averaging model deltas, this measures client drift. Drift is how
    far each client's post-training weights are from the weighted mean of
    all clients' weights. The weighted mean of that drift is passed to
    `server_update_fn`.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)
    client_optimizer_weights = tff.federated_broadcast(
        server_state.optimizer_state)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num,
         client_optimizer_weights))

    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(client_outputs.weights_delta,
                                     weight=client_weight)

    # Weighted mean of the clients' post-training weights, broadcast back so
    # each client's deviation from it can be computed.
    implicit_mean_broad = tff.federated_broadcast(
        tff.federated_mean(client_outputs.weights, weight=client_weight))

    @tff.tf_computation(model_weights_type.trainable,
                        model_weights_type.trainable)
    def individual_drift_fn(client_model, implicit_mean):
        return tf.nest.map_structure(lambda a, b: a - b, client_model,
                                     implicit_mean)

    drifts = tff.federated_map(
        individual_drift_fn, (client_outputs.weights, implicit_mean_broad))

    @tff.tf_computation(model_weights_type.trainable)
    def individual_drift_norm_fn(delta_from_mean):
        # NOTE: `tf.nn.l2_loss(t)` is `sum(t ** 2) / 2`, so this is HALF the
        # squared global L2 norm of the drift, not the norm itself.
        return tf.reduce_sum(
            [tf.nn.l2_loss(t) for t in tf.nest.flatten(delta_from_mean)])

    drift_norms = tff.federated_map(individual_drift_norm_fn, drifts)
    client_drift = tff.federated_mean(drift_norms, weight=client_weight)

    server_state = tff.federated_map(
        server_update_fn, (server_state, model_delta, client_drift))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
def run_one_round(server_state, federated_dataset): """Orchestration logic for one round of computation. Args: server_state: A `ServerState`. federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`. Returns: A tuple of updated `ServerState` and the result of `tff.learning.Model.federated_output_computation`. """ # Run computation on the clients. client_model = tff.federated_broadcast(server_state.model) client_round_num = tff.federated_broadcast(server_state.round_num) client_outputs = tff.federated_map( client_update_fn, (federated_dataset, client_model, client_round_num)) # Aggregate model deltas. client_weight = client_outputs.client_weight if mask_zeros_in_client_updates: model_delta = federated_mean_masked(client_outputs.weights_delta, client_weight) else: model_delta = tff.federated_mean(client_outputs.weights_delta, client_weight) server_state = tff.federated_map(server_update_fn, (server_state, model_delta)) # Aggregate model outputs that contain local metrics and various statistics. aggregated_outputs = placeholder_model.federated_output_computation( client_outputs.model_output) additional_outputs = tff.federated_mean( client_outputs.additional_output, weight=client_weight) @tff.tf_computation(aggregated_outputs.type_signature.member, additional_outputs.type_signature.member) def _update_aggregated_outputs(aggregated_outputs, additional_outputs): aggregated_outputs.update(additional_outputs) return aggregated_outputs aggregated_outputs = tff.federated_map( _update_aggregated_outputs, (aggregated_outputs, additional_outputs)) if aggregated_outputs.type_signature.is_struct(): aggregated_outputs = tff.federated_zip(aggregated_outputs) return server_state, aggregated_outputs
def federated_train(model, learning_rate, data, classes):
    """Broadcast training inputs, train on every client, then average.

    Args:
      model: Model placed at `tff.SERVER`; broadcast to clients.
      learning_rate: Learning rate placed at `tff.SERVER`; broadcast.
      data: Per-client data placed at `tff.CLIENTS`.
      classes: Value placed at `tff.SERVER`; broadcast to clients.

    Returns:
      The unweighted `tff.federated_mean` of the per-client `local_train`
      results.
    """
    model_at_clients = tff.federated_broadcast(model)
    lr_at_clients = tff.federated_broadcast(learning_rate)
    classes_at_clients = tff.federated_broadcast(classes)

    locally_trained = tff.federated_map(
        local_train,
        [model_at_clients, lr_at_clients, data, classes_at_clients])
    return tff.federated_mean(locally_trained)
def run_one_round(server_state, federated_dataset):
    """Run one round: broadcast, train locally, average deltas, update server.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    broadcast_model = tff.federated_broadcast(server_state.model)
    broadcast_round = tff.federated_broadcast(server_state.round_num)
    outputs = tff.federated_map(
        client_update_fn, (federated_dataset, broadcast_model, broadcast_round))

    weight = outputs.client_weight
    delta = tff.federated_mean(outputs.weights_delta, weight=weight)
    next_state = tff.federated_map(server_update_fn, (server_state, delta))

    metrics = dummy_model.federated_output_computation(outputs.model_output)
    # Only a Python struct of federated values needs to be zipped.
    if not metrics.type_signature.is_struct():
        return next_state, metrics
    return next_state, tff.federated_zip(metrics)
def next_fn(state, deltas, weights):
    """Clip each client delta to `clip_norm`, then take the weighted mean.

    Args:
      state: Aggregator state; returned unchanged.
      deltas: Federated client updates of type `model_update_type`.
      weights: Federated per-client weights for the mean.

    Returns:
      An `OrderedDict` with `state`, `result` (the weighted mean of the
      clipped deltas) and `measurements`. `measurements` is a zipped
      `NormClippedAggregationMetrics` holding the maximum pre-clipping global
      norm and the number of clipped clients.
    """

    @tff.tf_computation(model_update_type)
    def clip_by_global_norm(update):
        flat_update = tf.nest.flatten(update)
        clipped_flat, global_norm = tf.clip_by_global_norm(
            flat_update, tf.constant(clip_norm))
        # int32 1 when this client's update exceeded the limit, else 0.
        was_clipped = tf.cast(
            tf.greater(global_norm, tf.constant(clip_norm)), tf.int32)
        repacked = tf.nest.pack_sequence_as(update, clipped_flat)
        return repacked, global_norm, was_clipped

    clipped_deltas, client_norms, client_was_clipped = tff.federated_map(
        clip_by_global_norm, deltas)

    mean_clipped = tff.federated_mean(clipped_deltas, weight=weights)
    clip_metrics = tff.federated_zip(
        NormClippedAggregationMetrics(
            max_global_norm=tff.utils.federated_max(client_norms),
            num_clipped=tff.federated_sum(client_was_clipped),
        ))
    return collections.OrderedDict(
        state=state,
        result=mean_clipped,
        measurements=clip_metrics)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_outputs = tff.federated_map(client_update_fn,
                                       (federated_dataset, client_model))

    weight_denom = client_outputs.weights_delta_weight
    round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                           weight=weight_denom)

    # `tff.federated_map` accepts SERVER-placed arguments; the older
    # `tff.federated_apply` is deprecated.
    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    # Only a Python struct of federated values needs zipping; calling
    # `federated_zip` on an already-federated value would fail.
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
def aggregate_and_clip(global_state, value, weight=None):
    """Average `value`@CLIENTS, then clip the mean@SERVER to `norm_bound`.

    Args:
      global_state: Passed through unchanged.
      value: Federated value at `tff.CLIENTS`.
      weight: Optional federated per-client weight; `None` gives an
        unweighted mean.

    Returns:
      A tuple `(global_state, clipped_mean)`. `clipped_mean` is the mean of
      `value`, scaled down so its global L2 norm does not exceed
      `norm_bound`.
    """
    round_model_delta = tff.federated_mean(value, weight)
    value_type = value.type_signature.member

    # NOTE(review): `_asdict()` on the member type is kept from the original;
    # confirm it is valid for the TFF version in use.
    @tff.tf_computation(value_type._asdict())
    @tf.function
    def clip_by_norm(gradient, norm=norm_bound):
        """Clip the gradient so its global l_2 norm is at most `norm`."""
        bound = tf.cast(norm, tf.float32)
        delta_norm = tf.linalg.global_norm(tf.nest.flatten(gradient))
        if delta_norm < bound:
            return gradient
        else:
            delta_mul_factor = tf.math.divide_no_nan(bound, delta_norm)
            # Scale every leaf of the (possibly nested) structure. The old
            # per-key OrderedDict only worked for flat dicts.
            return tf.nest.map_structure(
                lambda g: tf.multiply(delta_mul_factor, g), gradient)

    return global_state, tff.federated_map(clip_by_norm, round_model_delta)
def federated_train(model, lr, data):
    """Train locally on every client and average the results.

    Returns the trained model: the unweighted `tff.federated_mean` of the
    per-client outputs of `local_train`.
    """
    train_args = [
        tff.federated_broadcast(model),
        tff.federated_broadcast(lr),
        data,
    ]
    per_client_models = tff.federated_map(local_train, train_args)
    return tff.federated_mean(per_client_models)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState` containing the state of the training
        process up to the current round.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS` containing data to train the current round on.

    Returns:
      A tuple of the updated `ServerState` and the aggregated metrics that
      `metrics_aggregation_computation` produces from the clients'
      `model_output`.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight
    round_model_delta = tff.federated_mean(client_outputs.weights_delta,
                                           weight=weight_denom)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))
    aggregated_outputs = metrics_aggregation_computation(
        client_outputs.model_output)

    return server_state, aggregated_outputs
def comp(temperatures, threshold):
    """Weighted federated mean of `count_over`, weighted by `count_total`.

    Args:
      temperatures: Per-client values at `tff.CLIENTS`.
      threshold: Value at `tff.SERVER`, broadcast to every client.

    Returns:
      The mean of each client's `count_over(temperatures, threshold)`,
      weighted by that client's `count_total(temperatures)`.
    """
    threshold_at_clients = tff.federated_broadcast(threshold)
    over_counts = tff.federated_map(
        count_over,
        tff.federated_zip([temperatures, threshold_at_clients]))
    total_counts = tff.federated_map(count_total, temperatures)
    # The second positional argument of `federated_mean` is the weight.
    return tff.federated_mean(over_counts, total_counts)
def next_fn(server_weights, federated_dataset):
    """One round of federated averaging.

    Args:
      server_weights: Model weights placed at `tff.SERVER`.
      federated_dataset: Client datasets placed at `tff.CLIENTS`.

    Returns:
      A tuple `(server_weights, client_weights)`: the updated server weights
      and the per-client weights produced by `client_update_fn`.
    """
    # Send the server weights to the clients.
    server_weights_to_clients = tff.federated_broadcast(server_weights)

    # Each client computes its updated weights.
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_to_clients))

    # The server averages all the client weights (one aggregation only).
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return (server_weights, client_weights)
def federated_evaluate(model_weights, federated_dataset):
    """Evaluate `model_weights` on every client and aggregate two ways.

    Args:
      model_weights: Model weights at `tff.SERVER`, broadcast to clients.
      federated_dataset: Client datasets at `tff.CLIENTS`.

    Returns:
      An `OrderedDict` with two entries:
        * `AggregationMethods.EXAMPLE_WEIGHTED.value`: the metrics averaged
          with each client weighted by its `num_examples`;
        * `AggregationMethods.UNIFORM_WEIGHTED.value`: the metrics averaged
          with every client weighted equally.
    """
    weights_at_clients = tff.federated_broadcast(model_weights)
    per_client_metrics = tff.federated_map(
        compute_client_metrics, (weights_at_clients, federated_dataset))

    # Each client's example count is its weight in the example-weighted mean.
    example_counts = per_client_metrics.num_examples

    uniform = tff.federated_mean(per_client_metrics, weight=None)
    by_examples = tff.federated_mean(per_client_metrics,
                                     weight=example_counts)

    return collections.OrderedDict([
        (AggregationMethods.EXAMPLE_WEIGHTED.value, by_examples),
        (AggregationMethods.UNIFORM_WEIGHTED.value, uniform),
    ])
def run_one_round(server_state, client_states):
    """Orchestration logic for one round of federated training computation.

    Args:
      server_state: Server state placed at `tff.SERVER`.
      client_states: Per-client values at `tff.CLIENTS`, averaged without
        weights (presumably the clients' model weights).

    Returns:
      The server state produced by `server_update_fn`.
    """
    # Federated averaging of the clients' weights.
    mean_client_weights = tff.federated_mean(client_states)

    # Server update step. `tff.federated_map` accepts SERVER-placed
    # arguments; the older `tff.federated_apply` is deprecated.
    server_state = tff.federated_map(server_update_fn,
                                     (server_state, mean_client_weights))

    return server_state
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Besides updating the server weights from the aggregated client gradients,
    this reads two monitored metrics from the outputs computed BEFORE local
    training. These are the `client_monitor` and `server_monitor` entries of
    `initial_aggregated_outputs`. Both are passed to `server_update_fn`,
    which uses them to update the learning-rate callbacks held in
    `server_state`.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of the updated `ServerState` and an `OrderedDict` with keys
      `before_training` and `during_training`. Each holds the result of
      `tff.learning.Model.federated_output_computation` on the initial and
      the post-training model outputs respectively.
    """
    client_model = tff.federated_broadcast(server_state.model)
    # Clients use the learning rate held by the client LR callback.
    client_lr = tff.federated_broadcast(
        server_state.client_lr_callback.learning_rate)

    # Optional per-client dataset preprocessing.
    if dataset_preprocess_comp is not None:
        federated_dataset = tff.federated_map(dataset_preprocess_comp,
                                              federated_dataset)
    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, client_model, client_lr))

    client_weight = client_outputs.client_weight
    aggregated_gradients = tff.federated_mean(
        client_outputs.accumulated_gradients, weight=client_weight)

    # Metrics of the broadcast model, computed before local training.
    initial_aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.initial_model_output)
    if isinstance(initial_aggregated_outputs.type_signature, tff.StructType):
        initial_aggregated_outputs = tff.federated_zip(
            initial_aggregated_outputs)

    # Metrics collected during local training.
    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

    # The monitored values come from the pre-training metrics.
    client_monitor_value = initial_aggregated_outputs[client_monitor]
    server_monitor_value = initial_aggregated_outputs[server_monitor]

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, aggregated_gradients, client_monitor_value,
         server_monitor_value))

    result = collections.OrderedDict(
        before_training=initial_aggregated_outputs,
        during_training=aggregated_outputs)

    return server_state, result
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    The server update depends on the enclosing `correction` setting:
      * 'joint': the client-weighted mean of `local_cor_states` is passed to
        `server_update_joint_cor_fn` together with the model delta.
      * 'local': only the model delta is passed to `server_update_fn`.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.

    Raises:
      TypeError: If `correction` is neither 'joint' nor 'local'.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_outputs = tff.federated_map(client_update_fn,
                                       (federated_dataset, client_model))

    client_weight = client_outputs.client_weight
    model_delta = tff.federated_mean(client_outputs.weights_delta,
                                     weight=client_weight)

    if correction == 'joint':
        # Only the joint variant uses the aggregated correction states, so
        # the extra aggregation is done only on this path.
        global_cor_states = tff.federated_mean(
            client_outputs.local_cor_states, weight=client_weight)
        server_state = tff.federated_map(
            server_update_joint_cor_fn,
            (server_state, model_delta, global_cor_states))
    elif correction == 'local':
        server_state = tff.federated_map(server_update_fn,
                                         (server_state, model_delta))
    else:
        # Kept as TypeError so existing callers that catch it still work.
        raise TypeError('Correction method must be local or joint.')

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
def aggregate_metrics_across_clients(metrics):
    """Aggregate per-client metrics onto the server.

    For each name in the module-level `metrics_name`:
      * 'num_examples' is summed across clients;
      * any other metric is averaged, weighted by `metrics.num_examples`.
    In both cases the raw per-client values are also collected under the key
    'per_client/<name>'.

    Args:
      metrics: Federated per-client metrics exposing each name in
        `metrics_name` as an attribute.

    Returns:
      An `OrderedDict` of aggregated and per-client values.
    """
    # `global` makes `metrics_name` resolve to the module-level binding.
    global metrics_name
    output = collections.OrderedDict()
    for name in metrics_name:
        client_values = getattr(metrics, name)
        if name == 'num_examples':
            output[name] = tff.federated_sum(client_values)
        else:
            output[name] = tff.federated_mean(client_values,
                                              metrics.num_examples)
        output['per_client/' + name] = tff.federated_collect(client_values)
    return output
def run_one_round(server_state, client_states, federated_dataset):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      client_states: Federated per-client state at `tff.CLIENTS`, passed to
        `client_update_fn`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      A 3-tuple of:
        * the updated `ServerState`,
        * the new per-client states returned by `client_update_fn`,
        * the mean of the clients' `model_output` (the round loss metric),
          weighted by `client_weight`.
    """
    # Prepare server_message to be sent to the clients,
    # based on the server_state from the previous round.
    server_message = tff.federated_map(server_message_fn, server_state)

    # Update the clients with the new server_message and dataset.
    client_outputs, new_client_state = tff.federated_map(
        client_update_fn,
        (
            federated_dataset,
            tff.federated_broadcast(server_message),
            client_states,
        ))

    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_outputs.client_weight)

    # Update the server state given the current round's completion.
    server_state = tff.federated_map(server_update_fn,
                                     (server_state, round_model_delta))

    round_loss_metric = tff.federated_mean(
        client_outputs.model_output, weight=client_outputs.client_weight)

    return server_state, new_client_state, round_loss_metric
def next_fn(server_weights, federated_dataset):
    """One round of federated averaging that also returns per-client losses.

    Args:
      server_weights: Model weights at `tff.SERVER`.
      federated_dataset: Client datasets at `tff.CLIENTS`.

    Returns:
      A tuple `(server_weights, client_weights, clients_loss)`: the updated
      server weights, plus the per-client weights and losses from
      `client_update_fn`.
    """
    # Broadcast the server model to the clients.
    weights_at_clients = tff.federated_broadcast(server_weights)

    # Each client trains locally and returns its updated weights and loss.
    client_weights, clients_loss = tff.federated_map(
        client_update_fn, (federated_dataset, weights_at_clients))

    # The server averages the clients' updated weights, then updates its model.
    averaged = tff.federated_mean(client_weights)
    new_server_weights = tff.federated_map(server_update_fn, averaged)

    return new_server_weights, client_weights, clients_loss
def next_tff(server_state, datasets, client_states):
    """One round of stateful federated training.

    Args:
      server_state: Server state at `tff.SERVER`.
      datasets: Client datasets at `tff.CLIENTS`.
      client_states: Per-client state at `tff.CLIENTS`.

    Returns:
      A tuple `(next_state, metrics, client_state)`:
        * the server state after `update_server_tf`,
        * the aggregated `model.federated_output_computation` metrics,
        * the per-client states emitted by `update_client_tf`.
    """
    message_at_clients = tff.federated_broadcast(
        tff.federated_map(state_to_message_tf, server_state))
    client_results = tff.federated_map(
        update_client_tf, (datasets, client_states, message_at_clients))

    mean_delta = tff.federated_mean(client_results.weights_delta,
                                    weight=client_results.client_weight)
    round_metrics = model.federated_output_computation(client_results.metrics)
    new_state = tff.federated_map(update_server_tf, (server_state, mean_delta))

    return new_state, round_metrics, client_results.client_state
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
    """The `tff.Computation` to be returned: one round of federated GAN training.

    Generator and discriminator weights are broadcast to the clients, and
    each client runs `client_computation`. The clients' discriminator
    weight deltas are then averaged, either by a weighted mean or by
    `gan.dp_averaging_fn` when it is set. Finally `server_computation`
    produces the next server state.

    Returns:
      The updated server state.
    """
    # TODO(b/131429028): The federated_zip should be automatic.
    from_server = tff.federated_zip(
        gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights))
    client_input = tff.federated_broadcast(from_server)
    client_outputs = tff.federated_map(
        client_computation,
        (client_gen_inputs, client_real_data, client_input))

    if gan.dp_averaging_fn is None:
        # Not using differential privacy.
        new_dp_averaging_state = server_state.dp_averaging_state
        averaged_discriminator_weights_delta = tff.federated_mean(
            client_outputs.discriminator_weights_delta,
            weight=client_outputs.update_weight)
    else:
        # Using differential privacy. Note that the weight argument is set to
        # None here. This is because the DP aggregation code explicitly does
        # not do weighted aggregation. (If weighted aggregation is desired,
        # differential privacy needs to be turned off.)
        new_dp_averaging_state, averaged_discriminator_weights_delta = (
            gan.dp_averaging_fn(server_state.dp_averaging_state,
                                client_outputs.discriminator_weights_delta,
                                weight=None))

    # TODO(b/131085687): Perhaps reconsider the choice to also use
    # ClientOutput to hold the aggregated client output.
    aggregated_client_output = gan_training_tf_fns.ClientOutput(
        discriminator_weights_delta=averaged_discriminator_weights_delta,
        # We don't actually need the aggregated update_weight, but
        # this keeps the types of the non-aggregated and aggregated
        # client_output the same, which is convenient. And I can
        # imagine wanting this.
        update_weight=tff.federated_sum(client_outputs.update_weight),
        counters=tff.federated_sum(client_outputs.counters))
    # TODO(b/131839522): This federated_zip shouldn't be needed.
    aggregated_client_output = tff.federated_zip(aggregated_client_output)

    server_state = tff.federated_map(
        server_computation,
        (server_state, server_gen_inputs, aggregated_client_output,
         new_dp_averaging_state))
    return server_state
def next_fn(server_weights, federated_dataset):
    """One round of federated averaging.

    Args:
      server_weights: Model weights at `tff.SERVER`.
      federated_dataset: Client datasets at `tff.CLIENTS`.

    Returns:
      The updated server weights.
    """
    # Debug `print` calls removed: they only ran while the computation was
    # being traced and dumped internal objects to stdout.

    # Broadcast the server weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client computes its updated weights.
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    # The server averages these updates.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.

    Both the model and the client optimizer state are broadcast. The
    clients' optimizer state deltas are aggregated with
    `optimizer_aggregator`, using the `optimizer_weight` component of
    `client_weight`. Model deltas are averaged using the `model_weight`
    component.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_optimizer_state = tff.federated_broadcast(
        server_state.client_optimizer_state)
    client_round_num = tff.federated_broadcast(server_state.round_num)
    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_optimizer_state,
         client_round_num))

    # Separate per-client weights for the model and the optimizer state.
    client_model_weight = client_outputs.client_weight.model_weight
    client_opt_weight = client_outputs.client_weight.optimizer_weight

    model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=client_model_weight)

    # We convert the optimizer state to a float type so that it can be used
    # with things such as `tff.federated_mean`. This is only necessary because
    # `tf.keras.Optimizer` objects have a state with an integer indicating
    # the number of times it has been applied.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_float, client_outputs.optimizer_state_delta)
    client_optimizer_state_delta = optimizer_aggregator(
        client_optimizer_state_delta, weight=client_opt_weight)
    # We convert the optimizer state back into one with an integer round
    # number.
    client_optimizer_state_delta = tff.federated_map(
        _convert_opt_state_to_int, client_optimizer_state_delta)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, model_delta, client_optimizer_state_delta))

    aggregated_outputs = placeholder_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
        aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
    """The `tff.Computation` to be returned: one round of federated GAN training.

    When `gan.dp_averaging_fn` is None, discriminator deltas are averaged
    with a weighted mean. Otherwise they go through
    `gan.dp_averaging_fn.next`, which is called with the DP averaging state
    and a constant client weight.

    Returns:
      The updated server state.
    """
    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_input = tff.federated_broadcast(from_server)
    client_outputs = tff.federated_map(
        client_computation,
        (client_gen_inputs, client_real_data, client_input))

    if gan.dp_averaging_fn is None:
        # Not using differential privacy.
        new_dp_averaging_state = server_state.dp_averaging_state
        averaged_discriminator_weights_delta = tff.federated_mean(
            client_outputs.discriminator_weights_delta,
            weight=client_outputs.update_weight)
    else:
        # Using differential privacy. Note that the weight argument is set to
        # a constant 1.0 here, however the underlying AggregationProcess
        # ignores the parameter and performs no weighting.
        ignored_weight = tff.federated_value(1.0, tff.CLIENTS)
        aggregation_output = gan.dp_averaging_fn.next(
            server_state.dp_averaging_state,
            client_outputs.discriminator_weights_delta,
            weight=ignored_weight)
        new_dp_averaging_state = aggregation_output.state
        averaged_discriminator_weights_delta = aggregation_output.result

    # TODO(b/131085687): Perhaps reconsider the choice to also use
    # ClientOutput to hold the aggregated client output.
    aggregated_client_output = gan_training_tf_fns.ClientOutput(
        discriminator_weights_delta=averaged_discriminator_weights_delta,
        # We don't actually need the aggregated update_weight, but
        # this keeps the types of the non-aggregated and aggregated
        # client_output the same, which is convenient. And I can
        # imagine wanting this.
        update_weight=tff.federated_sum(client_outputs.update_weight),
        counters=tff.federated_sum(client_outputs.counters))

    # The server computation is built here from the concrete server-state
    # member type. `client_output_type` comes from the enclosing scope.
    server_computation = build_server_computation(
        gan, server_state.type_signature.member, client_output_type)
    server_state = tff.federated_map(
        server_computation,
        (server_state, server_gen_inputs, aggregated_client_output,
         new_dp_averaging_state))
    return server_state
def next_fn(state, deltas, weights=None):
    """Clip client deltas to `state.clip_norm` and return their mean.

    Args:
      state: Aggregator state at `tff.SERVER` with a `clip_norm` field.
      deltas: Federated client deltas at `tff.CLIENTS`.
      weights: Optional federated weights for the mean.

    Returns:
      A tuple `(next_state, mean)`:
        * `next_state` is a zipped `ClipNormAggregateState` with `clip_norm`
          unchanged and `max_norm` set to the largest pre-clipping client
          global norm of this round;
        * `mean` is the `tff.federated_mean` of the clipped deltas.
    """

    @tff.tf_computation(deltas.type_signature.member, tf.float32)
    def clip_by_global_norm(delta, clip_norm):
        flat_clipped, norm_before_clip = tf.clip_by_global_norm(
            tf.nest.flatten(delta), clip_norm)
        return tf.nest.pack_sequence_as(delta, flat_clipped), norm_before_clip

    clip_norm_at_clients = tff.federated_broadcast(state.clip_norm)
    clipped_deltas, client_norms = tff.federated_map(
        clip_by_global_norm, (deltas, clip_norm_at_clients))

    # clip_norm is passed through unchanged (a no-op update); it could
    # instead be derived from max_norm.
    largest_norm = tff.utils.federated_max(client_norms)
    next_state = tff.federated_zip(
        ClipNormAggregateState(clip_norm=state.clip_norm,
                               max_norm=largest_norm))

    return next_state, tff.federated_mean(clipped_deltas, weight=weights)