def _get_gan(with_dp=False):
  gan_loss_fns = gan_losses.get_gan_loss_fns('wasserstein')
  server_gen_optimizer = tf.keras.optimizers.Adam()
  client_disc_optimizer = tf.keras.optimizers.Adam()
  train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
      gan_loss_fns, server_gen_optimizer)
  train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
      gan_loss_fns, client_disc_optimizer)

  if with_dp:
    dp_average_query = tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
        initial_l2_norm_clip=BEFORE_DP_L2_NORM_CLIP,
        noise_multiplier=0.3,
        target_unclipped_quantile=3,
        learning_rate=0.1,
        clipped_count_stddev=0.0,
        expected_num_records=10,
        denominator=10.0)
  else:
    dp_average_query = None

  return tff_gans.GanFnsAndTypes(
      generator_model_fn=one_dim_gan.create_generator,
      discriminator_model_fn=one_dim_gan.create_discriminator,
      dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
      dummy_real_data=next(iter(one_dim_gan.create_real_data())),
      train_generator_fn=train_generator_fn,
      train_discriminator_fn=train_discriminator_fn,
      server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
          learning_rate=1.0),
      train_discriminator_dp_average_query=dp_average_query)
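
# For context on what `gan_losses.get_gan_loss_fns('wasserstein')` wires in
# above: the standard Wasserstein (WGAN) objectives are sketched below. This
# is a minimal illustration of the math only; the names and signatures here
# are hypothetical, not this library's actual API.
def _wasserstein_disc_loss_sketch(disc_real_output, disc_fake_output):
  # The critic is trained to score real data above generated data.
  return tf.reduce_mean(disc_fake_output) - tf.reduce_mean(disc_real_output)


def _wasserstein_gen_loss_sketch(disc_fake_output):
  # The generator is trained to raise the critic's score on its samples.
  return -tf.reduce_mean(disc_fake_output)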
def test_build_gan_training_process(self, with_dp):
  gan = _get_gan(with_dp)
  process = tff_gans.build_gan_training_process(gan)
  server_state = process.initialize()

  if with_dp:
    # Check that initial DP averaging aggregator state is correct.
    dp_averaging_state = server_state.aggregation_state[0]
    self.assertAlmostEqual(
        dp_averaging_state.numerator_state.sum_state.l2_norm_clip,
        BEFORE_DP_L2_NORM_CLIP,
        places=5)
    self.assertAlmostEqual(
        dp_averaging_state.numerator_state.sum_state.stddev,
        BEFORE_DP_STD_DEV,
        places=5)

  client_dataset_sizes = [1, 3]
  client_gen_inputs = [
      one_dim_gan.create_generator_inputs().take(i)
      for i in client_dataset_sizes
  ]
  client_real_inputs = [
      one_dim_gan.create_real_data().take(i) for i in client_dataset_sizes
  ]

  num_rounds = 2
  for _ in range(num_rounds):
    server_state = process.next(server_state,
                                one_dim_gan.create_generator_inputs().take(1),
                                client_gen_inputs, client_real_inputs)

  # Check that server counters have incremented.
  counters = server_state.counters
  self.assertDictEqual(
      counters, {
          'num_rounds':
              num_rounds,
          'num_generator_train_examples':
              one_dim_gan.BATCH_SIZE * num_rounds,
          'num_discriminator_train_examples':
              num_rounds * one_dim_gan.BATCH_SIZE * sum(client_dataset_sizes),
      })

  if with_dp:
    # Check that DP averaging aggregator state has updated properly over the
    # above rounds.
    dp_averaging_state = server_state.aggregation_state[0]
    self.assertAlmostEqual(
        dp_averaging_state.numerator_state.sum_state.l2_norm_clip,
        AFTER_2_RDS_DP_L2_NORM_CLIP,
        places=5)
    self.assertAlmostEqual(
        dp_averaging_state.numerator_state.sum_state.stddev,
        AFTER_2_RDS_DP_STD_DEV,
        places=5)
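
# Why `l2_norm_clip` moves from BEFORE_DP_L2_NORM_CLIP toward
# AFTER_2_RDS_DP_L2_NORM_CLIP above: the quantile-adaptive query nudges the
# clip toward the target unclipped quantile every round. A back-of-envelope
# sketch, assuming tensorflow_privacy's linear update rule
# (clip += learning_rate * (target_unclipped_quantile - unclipped_fraction)):
# with target_unclipped_quantile=3 and an unclipped fraction of at most 1.0,
# each round adds at least 0.1 * (3 - 1) = 0.2, so the clip, and with it the
# noise stddev (which scales as noise_multiplier * clip), grows over the two
# rounds.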
def test_client_computation(self, with_dp):
  gan = _get_gan(with_dp)
  client_comp = tff_gans.build_client_computation(gan)

  generator = gan.generator_model_fn()
  discriminator = gan.discriminator_model_fn()
  from_server = gan_training_tf_fns.FromServer(
      generator_weights=generator.weights,
      discriminator_weights=discriminator.weights)

  client_output = client_comp(
      one_dim_gan.create_generator_inputs().take(10),
      one_dim_gan.create_real_data().take(10), from_server)

  self.assertDictEqual(
      client_output.counters,
      {'num_discriminator_train_examples': 10 * one_dim_gan.BATCH_SIZE})
def test_client_and_server_computations(self):
  train_generator_fn, train_discriminator_fn = (
      _get_train_generator_and_discriminator_fns())

  # N.B. The way we are using datasets and re-using the same generator and
  # discriminator doesn't really "make sense" from an ML perspective, but
  # it's sufficient for testing. For more proper usage of these functions,
  # see training_loops.py.
  generator = one_dim_gan.create_generator()
  discriminator = one_dim_gan.create_discriminator()
  gen_inputs = one_dim_gan.create_generator_inputs()
  real_data = one_dim_gan.create_real_data()

  server_state = gan_training_tf_fns.server_initial_state(
      generator, discriminator)
  # The aggregation state (e.g., used for handling DP averaging) is
  # initialized to be empty. A user of `server_initial_state` is expected to
  # take the output `ServerState` object and populate this field, most
  # likely via an instance of tff.templates.AggregationProcess.
  self.assertEmpty(server_state.aggregation_state)

  client_output = gan_training_tf_fns.client_computation(
      gen_inputs.take(3), real_data.take(3),
      gan_training_tf_fns.FromServer(
          generator_weights=server_state.generator_weights,
          discriminator_weights=server_state.discriminator_weights),
      generator, discriminator, train_discriminator_fn)

  server_disc_update_optimizer = tf.keras.optimizers.Adam()
  for _ in range(2):  # Train for 2 rounds.
    server_state = gan_training_tf_fns.server_computation(
        server_state, gen_inputs.take(3), client_output, generator,
        discriminator, server_disc_update_optimizer, train_generator_fn,
        NEW_DP_AVERAGING_STATE)

  counters = self.evaluate(server_state.counters)
  self.assertDictEqual(
      counters, {
          'num_rounds': 2,
          'num_discriminator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE,
          'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE
      })

  # DP averaging aggregation state updates properly in server_computation().
  self.assertEqual(server_state.aggregation_state, NEW_DP_AVERAGING_STATE)
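
# A minimal sketch of how a caller might populate the (initially empty)
# `aggregation_state` checked above. This assumes a
# `tff.aggregators.DifferentiallyPrivateFactory`-style API and a `ServerState`
# that supports `_replace`; the actual wiring for this codebase lives in
# tff_gans.py.
#
#   factory = tff.aggregators.DifferentiallyPrivateFactory(dp_average_query)
#   aggregation_process = factory.create(discriminator_weights_type)
#   server_state = server_state._replace(
#       aggregation_state=aggregation_process.initialize())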
def test_tff_training_loop(self, dp_average_query, checkpoint):
  if checkpoint:
    root_checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
  else:
    root_checkpoint_dir = None

  train_generator_fn, train_discriminator_fn = (
      _get_train_generator_and_discriminator_fns())
  gan = tff_gans.GanFnsAndTypes(
      generator_model_fn=one_dim_gan.create_generator,
      discriminator_model_fn=one_dim_gan.create_discriminator,
      dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
      dummy_real_data=next(iter(one_dim_gan.create_real_data())),
      train_generator_fn=train_generator_fn,
      train_discriminator_fn=train_discriminator_fn,
      server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
          learning_rate=1.0),
      train_discriminator_dp_average_query=dp_average_query)

  gen_inputs = one_dim_gan.create_generator_inputs()
  real_data = one_dim_gan.create_real_data()

  client_disc_train_steps = 2
  server_gen_train_steps = 3

  server_gen_inputs = iter(gen_inputs.window(server_gen_train_steps))
  client_gen_inputs = iter(gen_inputs.window(client_disc_train_steps))
  client_real_data = iter(real_data.window(client_disc_train_steps))

  def server_gen_inputs_fn(_):
    return next(server_gen_inputs)

  num_clients = 2

  def client_datasets_fn(_):
    return [(next(client_gen_inputs), next(client_real_data))
            for _ in range(num_clients)]

  server_state, _ = training_loops.federated_training_loop(
      gan,
      server_gen_inputs_fn=server_gen_inputs_fn,
      client_datasets_fn=client_datasets_fn,
      total_rounds=2,
      rounds_per_checkpoint=1,
      root_checkpoint_dir=root_checkpoint_dir)

  self.assertDictEqual(
      server_state.counters, {
          'num_rounds':
              2,
          'num_generator_train_examples':
              2 * 3 * one_dim_gan.BATCH_SIZE,
          'num_discriminator_train_examples':
              (2 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
      })

  if checkpoint:
    # TODO(b/141112101): We shouldn't need to re-create the gan; we should be
    # able to reuse the instance from above. See comment inside tff_gans.py.
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())
    gan = tff_gans.GanFnsAndTypes(
        generator_model_fn=one_dim_gan.create_generator,
        discriminator_model_fn=one_dim_gan.create_discriminator,
        dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
        dummy_real_data=next(iter(one_dim_gan.create_real_data())),
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)

    # Train one more round, which should resume from the checkpoint.
    server_state, _ = training_loops.federated_training_loop(
        gan,
        server_gen_inputs_fn=server_gen_inputs_fn,
        client_datasets_fn=client_datasets_fn,
        total_rounds=3,
        rounds_per_checkpoint=1,
        root_checkpoint_dir=root_checkpoint_dir)

    # Note: It would be better to return something from
    # federated_training_loop indicating the number of rounds trained in
    # this invocation, so we could verify the checkpoint was read.
    self.assertDictEqual(
        server_state.counters, {
            'num_rounds':
                3,
            'num_generator_train_examples':
                3 * 3 * one_dim_gan.BATCH_SIZE,
            'num_discriminator_train_examples':
                (3 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
        })
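
# For context on the `window` calls above: `tf.data.Dataset.window(n)` yields
# sub-datasets of (up to) n elements, which is how `client_datasets_fn` hands
# each client a fresh `client_disc_train_steps`-element dataset every round.
# A standalone illustration:
#
#   ds = tf.data.Dataset.range(6).window(2)
#   for w in ds:  # Each `w` is itself a tf.data.Dataset.
#     print(list(w.as_numpy_iterator()))  # [0, 1], then [2, 3], then [4, 5]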