def client_computation(gen_inputs, real_data, from_server):
    """Runs the client step for one round and returns the client_output.

    Fresh generator and discriminator models are built from `gan`, and the
    work is delegated to `gan_training_tf_fns.client_computation` together
    with `gan.train_discriminator_fn`.

    Args:
      gen_inputs: Generator-input data, forwarded as `gen_inputs_ds`.
      real_data: Real data, forwarded as `real_data_ds`.
      from_server: The payload received from the server, forwarded unchanged.

    Returns:
      Whatever `gan_training_tf_fns.client_computation` returns.
    """
    # `gan` is not a parameter; it must be resolvable from an enclosing or
    # module scope.
    client_generator = gan.generator_model_fn()
    client_discriminator = gan.discriminator_model_fn()
    return gan_training_tf_fns.client_computation(
        gen_inputs_ds=gen_inputs,
        real_data_ds=real_data,
        from_server=from_server,
        generator=client_generator,
        discriminator=client_discriminator,
        train_discriminator_fn=gan.train_discriminator_fn,
    )
def client_computation(gen_inputs, real_data, from_server, control_input_gen,
                       control_input_disc):
    """Runs the client step for one round and returns the client_output.

    Fresh generator/discriminator models and their optimizers are created,
    then everything is handed to `gan_training_tf_fns.client_computation`.

    Args:
      gen_inputs: Generator-input data, forwarded as `gen_inputs_ds`.
      real_data: Real data, forwarded as `real_data_ds`.
      from_server: The payload received from the server, forwarded unchanged.
      control_input_gen: Generator control input, forwarded unchanged.
      control_input_disc: Discriminator control input, forwarded unchanged.

    Returns:
      Whatever `gan_training_tf_fns.client_computation` returns.
    """
    # `gan`, `gen_optimizer_fn`, `disc_optimizer_fn` and `tau` are not
    # parameters; they must be resolvable from an enclosing or module scope.
    client_generator = gan.generator_model_fn()
    client_discriminator = gan.discriminator_model_fn()
    client_gen_optimizer = gen_optimizer_fn()
    client_disc_optimizer = disc_optimizer_fn()
    return gan_training_tf_fns.client_computation(
        gen_inputs_ds=gen_inputs,
        real_data_ds=real_data,
        from_server=from_server,
        generator=client_generator,
        discriminator=client_discriminator,
        gen_optimizer=client_gen_optimizer,
        disc_optimizer=client_disc_optimizer,
        control_input_gen=control_input_gen,
        control_input_disc=control_input_disc,
        tau=tau,
    )
def test_client_and_server_computations(self):
    """Runs one client computation and two server rounds, then checks state.

    Checks that the aggregation state starts empty, that the round and
    example counters reflect two server rounds, and that the aggregation
    state handed to `server_computation()` ends up in the server state.
    """
    num_rounds = 2
    num_batches = 3
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # N.B. The way we are using datasets and re-using the same
    # generator and discriminator doesn't really "make sense" from an ML
    # perspective, but it's sufficient for testing. For more proper usage of
    # these functions, see training_loops.py.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator)

    # The aggregation state (e.g., used for handling DP averaging) is
    # initialized to be empty. A user of the `server_initial_state` is expected
    # to take the output `ServerState` object and populate this field, most
    # likely via an instance of tff.templates.AggregationProcess.
    self.assertEmpty(server_state.aggregation_state)

    client_gen_inputs = gen_inputs.take(num_batches)
    client_real_data = real_data.take(num_batches)
    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_output = gan_training_tf_fns.client_computation(
        client_gen_inputs, client_real_data, from_server, generator,
        discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    # The single client_output computed above is reused in every round.
    for _ in range(num_rounds):
        server_state = gan_training_tf_fns.server_computation(
            server_state, gen_inputs.take(num_batches), client_output,
            generator, discriminator, server_disc_update_optimizer,
            train_generator_fn, NEW_DP_AVERAGING_STATE)

    counters = self.evaluate(server_state.counters)
    expected_examples = num_rounds * num_batches * one_dim_gan.BATCH_SIZE
    self.assertDictEqual(
        counters, {
            'num_rounds': num_rounds,
            'num_discriminator_train_examples': expected_examples,
            'num_generator_train_examples': expected_examples,
        })

    # DP averaging aggregation state updates properly in server_computation().
    self.assertEqual(server_state.aggregation_state, NEW_DP_AVERAGING_STATE)
def test_client_and_server_computations(self):
    """Runs one client computation and two server rounds, then checks state.

    Checks that the DP averaging state passed to `server_initial_state()` is
    stored, that the round and example counters reflect two server rounds,
    and that the state handed to `server_computation()` replaces it.
    """
    num_rounds = 2
    num_batches = 3
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # N.B. The way we are using datasets and re-using the same
    # generator and discriminator doesn't really "make sense" from an ML
    # perspective, but it's sufficient for testing. For more proper usage of
    # these functions, see training_loops.py.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator, INIT_DP_AVERAGING_STATE)

    # DP averaging aggregation state is initialized properly in
    # server_initial_state().
    self.assertEqual(server_state.dp_averaging_state, INIT_DP_AVERAGING_STATE)

    client_gen_inputs = gen_inputs.take(num_batches)
    client_real_data = real_data.take(num_batches)
    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_output = gan_training_tf_fns.client_computation(
        client_gen_inputs, client_real_data, from_server, generator,
        discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    # The single client_output computed above is reused in every round.
    for _ in range(num_rounds):
        server_state = gan_training_tf_fns.server_computation(
            server_state, gen_inputs.take(num_batches), client_output,
            generator, discriminator, server_disc_update_optimizer,
            train_generator_fn, NEW_DP_AVERAGING_STATE)

    counters = self.evaluate(server_state.counters)
    expected_examples = num_rounds * num_batches * one_dim_gan.BATCH_SIZE
    self.assertDictEqual(
        counters, {
            'num_rounds': num_rounds,
            'num_discriminator_train_examples': expected_examples,
            'num_generator_train_examples': expected_examples,
        })

    # DP averaging aggregation state updates properly in server_computation().
    self.assertEqual(server_state.dp_averaging_state, NEW_DP_AVERAGING_STATE)
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
    """Trains in TF using client_computation and server_computation.

    This is not intended to be a general-purpose training loop (e.g., the
    optimizers are hard-coded), it is primarily intended for testing.

    Args:
      generator_model_fn: A no-arg function returning the generator model.
      discriminator_model_fn: A no-arg function returning the discriminator
        model.
      real_data_fn: A no-arg function returning a dataset of real data
        batches.
      gen_inputs_fn: A no-arg function returning a dataset of generator input
        batches.
      train_generator_fn: A function which takes the two networks and
        generator input and trains the generator.
      train_discriminator_fn: A function which takes the two networks,
        generator input, and real data and trains the discriminator.
      total_rounds: Number of rounds to train.
      client_disc_train_steps: Number of discriminator training batches per
        round.
      server_gen_train_steps: Number of generator training batches per round.
      rounds_per_eval: How often (in rounds) to call the `eval_hook` function.
        Evaluation happens before every round whose index is a multiple of
        this value, and once more after the final round.
      eval_hook: A function taking arguments (generator, discriminator,
        server_state, round_num) and performs evaluation. Optional.

    Returns:
      A tuple (final `ServerState`, train_time_in_seconds). The training time
      covers the training loop, including the periodic evaluations run inside
      it, but not the final evaluation after the last round.
    """
    logging.info('Starting simple_training_loop')
    # N.B. We can't use real_data.take(...) in the loops below,
    # or we would get the same examples on every round. Using window
    # essentially breaks one Dataset into a sequence of Datasets,
    # which is exactly what we need here.
    client_gen_inputs = iter(gen_inputs_fn().window(client_disc_train_steps))
    client_real_data = iter(real_data_fn().window(client_disc_train_steps))
    server_gen_inputs = iter(gen_inputs_fn().window(server_gen_train_steps))

    server_generator = generator_model_fn()
    server_discriminator = discriminator_model_fn()
    # We could probably use a single copy of the generator and discriminator,
    # but using separate copies is more faithful to how this code will be used
    # in TFF.
    client_generator = generator_model_fn()
    client_discriminator = discriminator_model_fn()

    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
    server_state = gan_training_tf_fns.server_initial_state(
        server_generator, server_discriminator)

    # perf_counter() is monotonic and high-resolution, so measured durations
    # are not affected by system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()

    def do_eval(round_num):
        # `server_state` is read from the enclosing scope at call time, so each
        # evaluation sees the most recent state.
        eval_hook(server_generator, server_discriminator, server_state,
                  round_num)
        elapsed_minutes = (time.perf_counter() - start_time) / 60
        # Guard against zero elapsed time (possible for the round-0
        # evaluation on a coarse clock), which would raise ZeroDivisionError.
        rounds_per_minute = (
            round_num / elapsed_minutes if elapsed_minutes > 0 else 0.0)
        print('Total training time {:.2f} minutes for {} rounds '
              '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                                  rounds_per_minute),
              flush=True)

    logging.info('Starting training')
    for round_num in range(total_rounds):
        if round_num % rounds_per_eval == 0:
            do_eval(round_num)

        client_output = gan_training_tf_fns.client_computation(
            gen_inputs_ds=next(client_gen_inputs),
            real_data_ds=next(client_real_data),
            from_server=gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights),
            generator=client_generator,
            discriminator=client_discriminator,
            train_discriminator_fn=train_discriminator_fn)

        server_state = gan_training_tf_fns.server_computation(
            server_state=server_state,
            gen_inputs_ds=next(server_gen_inputs),
            client_output=client_output,
            generator=server_generator,
            discriminator=server_discriminator,
            server_disc_update_optimizer=server_disc_update_optimizer,
            train_generator_fn=train_generator_fn)

    # Measured before the final evaluation so that evaluation is excluded.
    train_time = time.perf_counter() - start_time
    do_eval(total_rounds)
    return server_state, train_time