def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """The `tff.Computation` to be returned."""
  # TODO(b/131429028): The federated_zip should be automatic.
  from_server = tff.federated_zip(
      gan_training_tf_fns.FromServer(
          generator_weights=server_state.generator_weights,
          discriminator_weights=server_state.discriminator_weights))
  client_input = tff.federated_broadcast(from_server)
  client_outputs = tff.federated_map(
      client_computation, (client_gen_inputs, client_real_data, client_input))

  if gan.dp_averaging_fn is None:
    # Not using differential privacy.
    new_dp_averaging_state = server_state.dp_averaging_state
    averaged_discriminator_weights_delta = tff.federated_mean(
        client_outputs.discriminator_weights_delta,
        weight=client_outputs.update_weight)
  else:
    # Using differential privacy. Note that the weight argument is set to None
    # here because the DP aggregation code explicitly does not do weighted
    # aggregation. (If weighted aggregation is desired, differential privacy
    # needs to be turned off.)
    new_dp_averaging_state, averaged_discriminator_weights_delta = (
        gan.dp_averaging_fn(server_state.dp_averaging_state,
                            client_outputs.discriminator_weights_delta,
                            weight=None))

  # TODO(b/131085687): Perhaps reconsider the choice to also use
  # ClientOutput to hold the aggregated client output.
  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_discriminator_weights_delta,
      # We don't actually need the aggregated update_weight, but keeping it
      # makes the types of the non-aggregated and aggregated client_output
      # the same, which is convenient, and we may want the value later.
      update_weight=tff.federated_sum(client_outputs.update_weight),
      counters=tff.federated_sum(client_outputs.counters))

  # TODO(b/131839522): This federated_zip shouldn't be needed.
  aggregated_client_output = tff.federated_zip(aggregated_client_output)

  server_state = tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output,
       new_dp_averaging_state))

  return server_state
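# A hedged driver sketch, not part of this module: it assumes the surrounding
# builder has wrapped `run_one_round` above (and a matching initialization
# computation) into an iterative process exposing `initialize`/`next`. The
# `gan_process` argument and the per-round dataset arguments are supplied by
# the caller; nothing here beyond the round signature is taken from this file.
def run_rounds(gan_process, total_rounds, server_gen_inputs, client_gen_inputs,
               client_real_data):
  """Drives `total_rounds` of federated GAN training (illustrative only)."""
  state = gan_process.initialize()
  for _ in range(total_rounds):
    # Reusing the same client datasets every round is only for illustration;
    # a real driver would sample fresh client data here.
    state = gan_process.next(state, server_gen_inputs, client_gen_inputs,
                             client_real_data)
  return state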
def __attrs_post_init__(self):
  self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
  self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

  # Model-weights based types
  self._generator = self.generator_model_fn()
  _ = self._generator(self.dummy_gen_input)
  py_typecheck.check_type(self._generator, tf.keras.models.Model)
  self._discriminator = self.discriminator_model_fn()
  _ = self._discriminator(self.dummy_real_data)
  py_typecheck.check_type(self._discriminator, tf.keras.models.Model)

  def vars_to_type(var_struct):
    # TODO(b/131681951): read_value() shouldn't be needed
    return tf.nest.map_structure(
        lambda v: tf.TensorSpec.from_tensor(v.read_value()), var_struct)

  self.discriminator_weights_type = vars_to_type(self._discriminator.weights)
  self.generator_weights_type = vars_to_type(self._generator.weights)
  self.from_server_type = gan_training_tf_fns.FromServer(
      generator_weights=self.generator_weights_type,
      discriminator_weights=self.discriminator_weights_type)

  self.client_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.CLIENTS)
  self.client_real_data_type = tff.FederatedType(
      tff.SequenceType(self.real_data_type), tff.CLIENTS)
  self.server_gen_input_type = tff.FederatedType(
      tff.SequenceType(self.gen_input_type), tff.SERVER)

  # Right now, the logic in this library is effectively "if DP use stateful
  # aggregator, else don't use stateful aggregator". An alternative
  # formulation would be to always use a stateful aggregator, but when not
  # using DP default the aggregator to be a stateless mean, e.g.,
  # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
  # This change will be easier to make if the tff.StatefulAggregateFn is
  # modified to have a property that gives the type of the aggregation state
  # (i.e., what we're storing in self.dp_averaging_state_type).
  if self.train_discriminator_dp_average_query is not None:
    self.dp_averaging_fn, self.dp_averaging_state_type = (
        tff.utils.build_dp_aggregate(
            query=self.train_discriminator_dp_average_query,
            value_type_fn=lambda value: self.discriminator_weights_type,
            from_tff_result_fn=lambda record: list(record)))  # pylint: disable=unnecessary-lambda
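# A hedged sketch of how `train_discriminator_dp_average_query` above might be
# built. It assumes a tensorflow_privacy release that still exports
# `GaussianAverageQuery` at the top level (newer releases restructure these
# queries); the clip norm, noise stddev, and per-round client count are
# illustrative values, not recommendations.
import tensorflow_privacy


def build_illustrative_dp_query(clients_per_round=10):
  """Returns a Gaussian averaging DPQuery for discriminator weight deltas."""
  return tensorflow_privacy.GaussianAverageQuery(
      l2_norm_clip=0.1,  # Clip each client's discriminator delta to this norm.
      sum_stddev=0.01,  # Stddev of Gaussian noise added to the clipped sum.
      denominator=clients_per_round)  # Divide by the expected client count.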
def test_client_computation(self, with_dp):
  gan = _get_gan(with_dp)
  client_comp = tff_gans.build_client_computation(gan)

  generator = gan.generator_model_fn()
  discriminator = gan.discriminator_model_fn()
  from_server = gan_training_tf_fns.FromServer(
      generator_weights=generator.weights,
      discriminator_weights=discriminator.weights)

  client_output = client_comp(
      one_dim_gan.create_generator_inputs().take(10),
      one_dim_gan.create_real_data().take(10), from_server)

  self.assertDictEqual(
      client_output.counters,
      {'num_discriminator_train_examples': 10 * one_dim_gan.BATCH_SIZE})
def test_client_and_server_computations(self):
  train_generator_fn, train_discriminator_fn = (
      _get_train_generator_and_discriminator_fns())

  # N.B. The way we are using datasets and re-using the same
  # generator and discriminator doesn't really "make sense" from an ML
  # perspective, but it's sufficient for testing. For more proper usage of
  # these functions, see training_loops.py.
  generator = one_dim_gan.create_generator()
  discriminator = one_dim_gan.create_discriminator()
  gen_inputs = one_dim_gan.create_generator_inputs()
  real_data = one_dim_gan.create_real_data()

  server_state = gan_training_tf_fns.server_initial_state(
      generator, discriminator, INIT_DP_AVERAGING_STATE)
  # DP averaging aggregation state is initialized properly in
  # server_initial_state().
  self.assertEqual(server_state.dp_averaging_state, INIT_DP_AVERAGING_STATE)

  client_output = gan_training_tf_fns.client_computation(
      gen_inputs.take(3), real_data.take(3),
      gan_training_tf_fns.FromServer(
          generator_weights=server_state.generator_weights,
          discriminator_weights=server_state.discriminator_weights),
      generator, discriminator, train_discriminator_fn)

  server_disc_update_optimizer = tf.keras.optimizers.Adam()
  for _ in range(2):  # Train for 2 rounds.
    server_state = gan_training_tf_fns.server_computation(
        server_state, gen_inputs.take(3), client_output, generator,
        discriminator, server_disc_update_optimizer, train_generator_fn,
        NEW_DP_AVERAGING_STATE)

  counters = self.evaluate(server_state.counters)
  self.assertDictEqual(
      counters, {
          'num_rounds': 2,
          'num_discriminator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE,
          'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE
      })

  # DP averaging aggregation state updates properly in server_computation().
  self.assertEqual(server_state.dp_averaging_state, NEW_DP_AVERAGING_STATE)
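# A hedged sketch of the train-step callables used above (and described in the
# `simple_training_loop` docstring below): the generator step takes the two
# networks plus a batch of generator inputs, and the discriminator step
# additionally takes a batch of real data. The losses, optimizers, and exact
# argument order are illustrative assumptions, not this library's
# implementation.
def make_illustrative_train_fns(gen_lr=0.005, disc_lr=0.005):
  """Returns (train_generator_fn, train_discriminator_fn) toy train steps."""
  gen_optimizer = tf.keras.optimizers.SGD(learning_rate=gen_lr)
  disc_optimizer = tf.keras.optimizers.SGD(learning_rate=disc_lr)
  bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

  def train_generator_fn(generator, discriminator, gen_inputs):
    with tf.GradientTape() as tape:
      scores = discriminator(generator(gen_inputs, training=True),
                             training=False)
      loss = bce(tf.ones_like(scores), scores)  # Non-saturating generator loss.
    grads = tape.gradient(loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

  def train_discriminator_fn(generator, discriminator, gen_inputs, real_data):
    with tf.GradientTape() as tape:
      real_scores = discriminator(real_data, training=True)
      fake_scores = discriminator(generator(gen_inputs, training=False),
                                  training=True)
      loss = (bce(tf.ones_like(real_scores), real_scores) +
              bce(tf.zeros_like(fake_scores), fake_scores))
    grads = tape.gradient(loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

  return train_generator_fn, train_discriminator_fn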
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
  """Trains in TF using client_computation and server_computation.

  This is not intended to be a general-purpose training loop (e.g., the
  optimizers are hard-coded); it is primarily intended for testing.

  Args:
    generator_model_fn: A no-arg function returning the generator model.
    discriminator_model_fn: A no-arg function returning the discriminator
      model.
    real_data_fn: A no-arg function returning a dataset of real data batches.
    gen_inputs_fn: A no-arg function returning a dataset of generator input
      batches.
    train_generator_fn: A function which takes the two networks and generator
      input and trains the generator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.
    total_rounds: Number of rounds to train.
    client_disc_train_steps: Number of discriminator training batches per
      round.
    server_gen_train_steps: Number of generator training batches per round.
    rounds_per_eval: How often to call the `eval_hook` function.
    eval_hook: An optional function taking arguments (generator,
      discriminator, server_state, round_num) that performs evaluation.

  Returns:
    A tuple (final `ServerState`, train_time_in_seconds).
  """
  logging.info('Starting simple_training_loop')
  # N.B. We can't use real_data.take(...) in the loops below, or we would get
  # the same examples on every round. Using window essentially breaks one
  # Dataset into a sequence of Datasets, which is exactly what we need here
  # (see the small `Dataset.window` sketch after this function).
  client_gen_inputs = iter(gen_inputs_fn().window(client_disc_train_steps))
  client_real_data = iter(real_data_fn().window(client_disc_train_steps))

  server_gen_inputs = iter(gen_inputs_fn().window(server_gen_train_steps))

  server_generator = generator_model_fn()
  server_discriminator = discriminator_model_fn()
  # We could probably use a single copy of the generator and discriminator,
  # but using separate copies is more faithful to how this code will be used
  # in TFF.
  client_generator = generator_model_fn()
  client_discriminator = discriminator_model_fn()

  server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

  server_state = gan_training_tf_fns.server_initial_state(
      server_generator, server_discriminator)

  start_time = time.time()

  def do_eval(round_num):
    eval_hook(server_generator, server_discriminator, server_state, round_num)
    elapsed_minutes = (time.time() - start_time) / 60
    print(
        'Total training time {:.2f} minutes for {} rounds '
        '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                            round_num / elapsed_minutes),
        flush=True)

  logging.info('Starting training')
  for round_num in range(total_rounds):
    if round_num % rounds_per_eval == 0:
      do_eval(round_num)

    client_output = gan_training_tf_fns.client_computation(
        gen_inputs_ds=next(client_gen_inputs),
        real_data_ds=next(client_real_data),
        from_server=gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights),
        generator=client_generator,
        discriminator=client_discriminator,
        train_discriminator_fn=train_discriminator_fn)

    server_state = gan_training_tf_fns.server_computation(
        server_state=server_state,
        gen_inputs_ds=next(server_gen_inputs),
        client_output=client_output,
        generator=server_generator,
        discriminator=server_discriminator,
        server_disc_update_optimizer=server_disc_update_optimizer,
        train_generator_fn=train_generator_fn)

  train_time = time.time() - start_time
  do_eval(total_rounds)
  return server_state, train_time
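# A minimal sketch of the `Dataset.window` trick referenced in the docstring
# above: `window(k)` turns one dataset into a dataset of k-element
# sub-datasets, so each `next()` call in the round loop sees fresh batches
# instead of replaying the same `take(k)` prefix every round. Standalone toy
# example, not part of the training loop.
def _window_demo():
  ds = tf.data.Dataset.range(6)
  windows = iter(ds.window(2))
  first = list(next(windows).as_numpy_iterator())   # [0, 1]
  second = list(next(windows).as_numpy_iterator())  # [2, 3]
  return first, second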
def run_one_round(server_state, server_gen_inputs, client_gen_inputs,
                  client_real_data):
  """The `tff.Computation` to be returned."""
  from_server = gan_training_tf_fns.FromServer(
      generator_weights=server_state.generator_weights,
      discriminator_weights=server_state.discriminator_weights,
      state_gen_optimizer_weights=server_state.state_gen_optimizer_weights,
      state_disc_optimizer_weights=server_state.state_disc_optimizer_weights,
      counters=server_state.counters)
  client_input = tff.federated_broadcast(from_server)
  client_outputs = tff.federated_map(
      client_computation, (client_gen_inputs, client_real_data, client_input))

  if gan.dp_averaging_fn is None:
    # Not using differential privacy.
    new_dp_averaging_state = server_state.dp_averaging_state
    averaged_discriminator_weights_delta = tff.federated_mean(
        client_outputs.discriminator_weights_delta,
        weight=client_outputs.update_weight_G)
    averaged_generator_weights_delta = tff.federated_mean(
        client_outputs.generator_weights_delta,
        weight=client_outputs.update_weight_G)
    averaged_gen_opt_delta = tff.federated_mean(
        client_outputs.state_gen_opt_delta,
        weight=client_outputs.update_weight_G)
    averaged_disc_opt_delta = tff.federated_mean(
        client_outputs.state_disc_opt_delta,
        weight=client_outputs.update_weight_G)
  else:
    # Using differential privacy. Note that the weight argument is set to None
    # here because the DP aggregation code explicitly does not do weighted
    # aggregation. (If weighted aggregation is desired, differential privacy
    # needs to be turned off.)
    new_dp_averaging_state, averaged_discriminator_weights_delta = (
        gan.dp_averaging_fn(server_state.dp_averaging_state,
                            client_outputs.discriminator_weights_delta,
                            weight=None))
    # DP noising is only applied to the discriminator delta; the generator and
    # optimizer-state deltas are still aggregated with a weighted mean so that
    # the aggregated ClientOutput below is fully populated in this branch too.
    averaged_generator_weights_delta = tff.federated_mean(
        client_outputs.generator_weights_delta,
        weight=client_outputs.update_weight_G)
    averaged_gen_opt_delta = tff.federated_mean(
        client_outputs.state_gen_opt_delta,
        weight=client_outputs.update_weight_G)
    averaged_disc_opt_delta = tff.federated_mean(
        client_outputs.state_disc_opt_delta,
        weight=client_outputs.update_weight_G)

  aggregated_client_output = gan_training_tf_fns.ClientOutput(
      discriminator_weights_delta=averaged_discriminator_weights_delta,
      generator_weights_delta=averaged_generator_weights_delta,
      state_gen_opt_delta=averaged_gen_opt_delta,
      state_disc_opt_delta=averaged_disc_opt_delta,
      # We don't actually need the aggregated update_weight, but keeping it
      # makes the types of the non-aggregated and aggregated client_output
      # the same, which is convenient, and we may want the value later.
      update_weight=tff.federated_sum(client_outputs.update_weight_D),
      update_weight_D=tff.federated_sum(client_outputs.update_weight_D),
      update_weight_G=tff.federated_sum(client_outputs.update_weight_G),
      counters=tff.federated_sum(client_outputs.counters))

  server_state = tff.federated_map(
      server_computation,
      (server_state, server_gen_inputs, aggregated_client_output,
       new_dp_averaging_state))

  return server_state