def client_output_fn():
  """Returns a `ClientOutput` whose discriminator deltas are all zeros.

  A fresh discriminator is built via `gan.discriminator_model_fn()`, and one
  `tf.zeros` tensor is created per discriminator weight, matching that
  weight's shape and dtype. The update weight is fixed at 1.0 and the
  `num_discriminator_train_examples` counter is fixed at 13.

  NOTE(review): the constant values look like placeholders (e.g. for deriving
  the client output type signature) rather than a real client update --
  confirm against the callers.
  """
  model = gan.discriminator_model_fn()
  zero_deltas = [
      tf.zeros(shape=weight.shape, dtype=weight.dtype)
      for weight in model.weights
  ]
  return gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=zero_deltas,
      update_weight=1.0,
      counters={'num_discriminator_train_examples': 13})
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """Runs one federated GAN round, optionally with DP averaging.

  The server's generator and discriminator weights are broadcast to the
  clients, `client_computation` is mapped over the clients, and the resulting
  discriminator deltas are aggregated and handed to the server computation.

  Aggregation depends on `gan.dp_averaging_fn`:
    * `None`: a `tff.federated_mean` weighted by each client's
      `update_weight`. The DP averaging state is passed through unchanged.
    * otherwise: `gan.dp_averaging_fn.next(...)` supplies both the averaged
      delta and the new DP averaging state. A constant client weight of 1.0
      is passed, which the underlying aggregation process ignores.

  Args:
    server_state: The federated server state. Its `generator_weights`,
      `discriminator_weights` and `dp_averaging_state` fields are read here.
    server_gen_inputs: Generator inputs forwarded to the server computation.
    client_gen_inputs: Per-client generator inputs for `client_computation`.
    client_real_data: Per-client real data for `client_computation`.

  Returns:
    The server state produced by mapping the server computation.
  """
  broadcast_payload = gan_training_tf_fns.FromServer(
      generator_weights=server_state.generator_weights,
      discriminator_weights=server_state.discriminator_weights)
  client_input = tff.federated_broadcast(broadcast_payload)
  client_outputs = tff.federated_map(
      client_computation,
      (client_gen_inputs, client_real_data, client_input))

  if gan.dp_averaging_fn is not None:
    # Differential privacy: the weight is a constant 1.0 only to satisfy the
    # call signature; the underlying AggregationProcess performs no weighting.
    dp_output = gan.dp_averaging_fn.next(
        server_state.dp_averaging_state,
        client_outputs.discriminator_weights_delta,
        weight=tff.federated_value(1.0, tff.CLIENTS))
    dp_state = dp_output.state
    averaged_delta = dp_output.result
  else:
    # No differential privacy: keep the DP state as-is and take a weighted
    # mean of the client deltas.
    dp_state = server_state.dp_averaging_state
    averaged_delta = tff.federated_mean(
        client_outputs.discriminator_weights_delta,
        weight=client_outputs.update_weight)

  # TODO(b/131085687): Perhaps reconsider the choice to also use
  # ClientOutput to hold the aggregated client output.
  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_delta,
      # The summed update_weight is not actually needed; computing it keeps
      # the aggregated and per-client ClientOutput types identical.
      update_weight=tff.federated_sum(client_outputs.update_weight),
      counters=tff.federated_sum(client_outputs.counters))

  server_computation = build_server_computation(
      gan, server_state.type_signature.member, client_output_type)
  return tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output, dp_state))
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """Runs one federated GAN round using `gan.aggregation_process`.

  The server's generator and discriminator weights are broadcast to the
  clients, `client_computation` is mapped over the clients, the discriminator
  deltas are aggregated with `gan.aggregation_process.next`, and the result
  is handed to the server computation along with the new aggregation state.

  Args:
    server_state: The federated server state. Its `generator_weights`,
      `discriminator_weights` and `aggregation_state` fields are read here.
    server_gen_inputs: Generator inputs forwarded to the server computation.
    client_gen_inputs: Per-client generator inputs for `client_computation`.
    client_real_data: Per-client real data for `client_computation`.

  Returns:
    The server state produced by mapping the server computation.
  """
  broadcast_payload = gan_training_tf_fns.FromServer(
      generator_weights=server_state.generator_weights,
      discriminator_weights=server_state.discriminator_weights)
  client_input = tff.federated_broadcast(broadcast_payload)
  client_outputs = tff.federated_map(
      client_computation,
      (client_gen_inputs, client_real_data, client_input))

  aggregator = gan.aggregation_process
  next_args = [
      server_state.aggregation_state,
      client_outputs.discriminator_weights_delta,
  ]
  # Only weighted processes accept the per-client update_weight. Unweighted
  # ones (e.g. differentially private aggregation, which weights clients
  # uniformly) take no weight argument at all.
  if aggregator.is_weighted:
    next_args.append(client_outputs.update_weight)
  aggregation_output = aggregator.next(*next_args)
  new_aggregation_state = aggregation_output.state
  averaged_delta = aggregation_output.result

  # TODO(b/131085687): Perhaps reconsider the choice to also use
  # ClientOutput to hold the aggregated client output.
  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_delta,
      # The summed update_weight is not actually needed; computing it keeps
      # the aggregated and per-client ClientOutput types identical.
      update_weight=tff.federated_sum(client_outputs.update_weight),
      counters=tff.federated_sum(client_outputs.counters))

  server_computation = build_server_computation(
      gan, server_state.type_signature.member, client_output_type,
      aggregator.state_type.member)
  return tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output,
       new_aggregation_state))