def _get_gan(gen_model_fn, disc_model_fn, gan_loss_fns, gen_optimizer,
             disc_optimizer, server_gen_inputs_dataset,
             client_real_images_tff_data, use_dp, dp_l2_norm_clip,
             dp_noise_multiplier, clients_per_round):
  """Construct instance of tff_gans.GanFnsAndTypes class.

  Args:
    gen_model_fn: Passed through as `generator_model_fn`.
    disc_model_fn: Passed through as `discriminator_model_fn`.
    gan_loss_fns: GAN loss functions used to build both the generator and the
      discriminator training functions.
    gen_optimizer: Optimizer handed to the generator training function.
    disc_optimizer: Optimizer handed to the discriminator training function.
    server_gen_inputs_dataset: Iterable (e.g. a `tf.data.Dataset`) whose first
      element is used as `dummy_gen_input`.
    client_real_images_tff_data: Client data object exposing `client_ids` and
      `create_tf_dataset_for_client`; the first element of the first client's
      dataset is used as `dummy_real_data`.
    use_dp: If True, attach a Gaussian DP average query to discriminator
      training; otherwise no DP query is used.
    dp_l2_norm_clip: L2 clipping norm of the DP query (only read if `use_dp`).
    dp_noise_multiplier: Noise stddev as a multiple of `dp_l2_norm_clip`
      (only read if `use_dp`).
    clients_per_round: Fixed denominator of the DP average (only read if
      `use_dp`).

  Returns:
    A `tff_gans.GanFnsAndTypes` instance.
  """
  # Sample elements passed as `dummy_gen_input` / `dummy_real_data`;
  # presumably tff_gans uses them only to infer input types/shapes.
  dummy_gen_input = next(iter(server_gen_inputs_dataset))
  first_client_id = client_real_images_tff_data.client_ids[0]
  dummy_real_data = next(
      iter(
          client_real_images_tff_data.create_tf_dataset_for_client(
              first_client_id)))

  train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
      gan_loss_fns, gen_optimizer)
  train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
      gan_loss_fns, disc_optimizer)

  dp_average_query = None
  if use_dp:
    # The noise stddev scales with the clip norm. The noised sum is divided by
    # the fixed `clients_per_round` rather than by the number of participants.
    dp_average_query = tensorflow_privacy.GaussianAverageQuery(
        l2_norm_clip=dp_l2_norm_clip,
        sum_stddev=dp_l2_norm_clip * dp_noise_multiplier,
        denominator=clients_per_round)

  return tff_gans.GanFnsAndTypes(
      generator_model_fn=gen_model_fn,
      discriminator_model_fn=disc_model_fn,
      dummy_gen_input=dummy_gen_input,
      dummy_real_data=dummy_real_data,
      train_generator_fn=train_generator_fn,
      train_discriminator_fn=train_discriminator_fn,
      # Use `learning_rate`, not the deprecated `lr` alias. Newer Keras
      # optimizers no longer honour `lr` and would silently use their default
      # rate instead of 1.0. A rate of 1.0 presumably applies the aggregated
      # discriminator update unscaled.
      server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
          learning_rate=1.0),
      train_discriminator_dp_average_query=dp_average_query)
def _get_gan(with_dp=False):
  """Build a `tff_gans.GanFnsAndTypes` for the one-dimensional test GAN.

  Uses Wasserstein GAN losses, Adam optimizers for generator and
  discriminator training, and the `one_dim_gan` model and data helpers.

  Args:
    with_dp: If True, attach a quantile-adaptive-clipping DP average query to
      discriminator training; otherwise no DP query is used.

  Returns:
    A `tff_gans.GanFnsAndTypes` instance.
  """
  gan_loss_fns = gan_losses.get_gan_loss_fns('wasserstein')
  server_gen_optimizer = tf.keras.optimizers.Adam()
  client_disc_optimizer = tf.keras.optimizers.Adam()
  train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
      gan_loss_fns, server_gen_optimizer)
  train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
      gan_loss_fns, client_disc_optimizer)

  if with_dp:
    # NOTE(review): a target quantile of 3 lies outside [0, 1]. This is
    # presumably deliberate, so the adaptive clip is guaranteed to move away
    # from BEFORE_DP_L2_NORM_CLIP during a round. Confirm against the tests
    # that read it.
    dp_average_query = tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
        initial_l2_norm_clip=BEFORE_DP_L2_NORM_CLIP,
        noise_multiplier=0.3,
        target_unclipped_quantile=3,
        learning_rate=0.1,
        clipped_count_stddev=0.0,
        expected_num_records=10,
        denominator=10.0)
  else:
    dp_average_query = None

  return tff_gans.GanFnsAndTypes(
      generator_model_fn=one_dim_gan.create_generator,
      discriminator_model_fn=one_dim_gan.create_discriminator,
      dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
      dummy_real_data=next(iter(one_dim_gan.create_real_data())),
      train_generator_fn=train_generator_fn,
      train_discriminator_fn=train_discriminator_fn,
      # `learning_rate` rather than the deprecated `lr` alias, which newer
      # Keras optimizers no longer honour (they would fall back to their
      # default rate instead of 1.0).
      server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
          learning_rate=1.0),
      train_discriminator_dp_average_query=dp_average_query)
def test_tff_training_loop(self, dp_average_query, checkpoint):
  """Runs the federated GAN training loop and checks the round counters.

  Args:
    dp_average_query: DP query for discriminator training, or None.
    checkpoint: If True, checkpoint every round into a temp directory, then
      run the loop again with `total_rounds=3` and check the counters.
  """
  if checkpoint:
    root_checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
  else:
    root_checkpoint_dir = None

  def create_gan():
    """Builds a fresh GanFnsAndTypes for the one-dimensional test GAN."""
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())
    return tff_gans.GanFnsAndTypes(
        generator_model_fn=one_dim_gan.create_generator,
        discriminator_model_fn=one_dim_gan.create_discriminator,
        dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
        dummy_real_data=next(iter(one_dim_gan.create_real_data())),
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # `learning_rate` rather than the deprecated `lr` alias, which newer
        # Keras optimizers no longer honour.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)

  gan = create_gan()

  gen_inputs = one_dim_gan.create_generator_inputs()
  real_data = one_dim_gan.create_real_data()

  client_disc_train_steps = 2
  server_gen_train_steps = 3
  num_clients = 2

  # Each window holds the batches consumed by one round of training.
  server_gen_inputs = iter(gen_inputs.window(server_gen_train_steps))
  client_gen_inputs = iter(gen_inputs.window(client_disc_train_steps))
  client_real_data = iter(real_data.window(client_disc_train_steps))

  def server_gen_inputs_fn(_):
    return next(server_gen_inputs)

  def client_datasets_fn(_):
    return [(next(client_gen_inputs), next(client_real_data))
            for _ in range(num_clients)]

  def expected_counters(num_rounds):
    """Counter values expected after `num_rounds` completed rounds."""
    return {
        'num_rounds': num_rounds,
        'num_generator_train_examples':
            num_rounds * server_gen_train_steps * one_dim_gan.BATCH_SIZE,
        'num_discriminator_train_examples':
            (num_rounds * client_disc_train_steps * one_dim_gan.BATCH_SIZE *
             num_clients),
    }

  server_state, _ = training_loops.federated_training_loop(
      gan,
      server_gen_inputs_fn=server_gen_inputs_fn,
      client_datasets_fn=client_datasets_fn,
      total_rounds=2,
      rounds_per_checkpoint=1,
      root_checkpoint_dir=root_checkpoint_dir)

  self.assertDictEqual(server_state.counters, expected_counters(2))

  if checkpoint:
    # TODO(b/141112101): We shouldn't need to re-create the gan, should be
    # able to reuse the instance from above. See comment inside tff_gans.py.
    gan = create_gan()

    # Train one more round, which should resume from the checkpoint.
    server_state, _ = training_loops.federated_training_loop(
        gan,
        server_gen_inputs_fn=server_gen_inputs_fn,
        client_datasets_fn=client_datasets_fn,
        total_rounds=3,
        rounds_per_checkpoint=1,
        root_checkpoint_dir=root_checkpoint_dir)

    # Note: It would be better to return something from
    # federated_training_loop indicating the number of rounds trained in this
    # invocation, so we could verify the checkpoint was read.
    self.assertDictEqual(server_state.counters, expected_counters(3))