# Example #1
def _get_gan(gen_model_fn, disc_model_fn, gan_loss_fns, gen_optimizer,
             disc_optimizer, server_gen_inputs_dataset,
             client_real_images_tff_data, use_dp, dp_l2_norm_clip,
             dp_noise_multiplier, clients_per_round):
    """Construct instance of tff_gans.GanFnsAndTypes class.

    Args:
      gen_model_fn: No-arg callable returning a generator model.
      disc_model_fn: No-arg callable returning a discriminator model.
      gan_loss_fns: Object providing generator/discriminator loss functions,
        passed through to the train-step factories.
      gen_optimizer: Optimizer used in the generator train step.
      disc_optimizer: Optimizer used in the discriminator train step.
      server_gen_inputs_dataset: Iterable of generator inputs; one element is
        drawn as the dummy generator input for type inference.
      client_real_images_tff_data: Federated data source; one element from its
        first client is drawn as the dummy real-data example.
      use_dp: If True, attach a Gaussian differential-privacy average query.
      dp_l2_norm_clip: L2 norm clip for the DP query.
      dp_noise_multiplier: Noise stddev is dp_l2_norm_clip * this value.
      clients_per_round: Denominator used when averaging clipped updates.

    Returns:
      A populated `tff_gans.GanFnsAndTypes` instance.
    """
    # Single elements used only for TFF type inference, not for training.
    dummy_gen_input = next(iter(server_gen_inputs_dataset))
    dummy_real_data = next(
        iter(
            client_real_images_tff_data.create_tf_dataset_for_client(
                client_real_images_tff_data.client_ids[0])))

    train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
        gan_loss_fns, gen_optimizer)
    train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
        gan_loss_fns, disc_optimizer)

    dp_average_query = None
    if use_dp:
        dp_average_query = tensorflow_privacy.GaussianAverageQuery(
            l2_norm_clip=dp_l2_norm_clip,
            sum_stddev=dp_l2_norm_clip * dp_noise_multiplier,
            denominator=clients_per_round)

    return tff_gans.GanFnsAndTypes(
        generator_model_fn=gen_model_fn,
        discriminator_model_fn=disc_model_fn,
        dummy_gen_input=dummy_gen_input,
        dummy_real_data=dummy_real_data,
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # `learning_rate` replaces the deprecated `lr` alias (removed in
        # modern Keras optimizers); the value is unchanged.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)
# Example #2
def _get_gan(with_dp=False):
    """Build a `tff_gans.GanFnsAndTypes` for the one-dim test GAN.

    Args:
      with_dp: If True, attach an adaptive-clipping differential-privacy
        average query to the discriminator updates; otherwise no DP is used.

    Returns:
      A populated `tff_gans.GanFnsAndTypes` instance using the Wasserstein
      loss and Adam optimizers for both generator and discriminator.
    """
    gan_loss_fns = gan_losses.get_gan_loss_fns('wasserstein')
    server_gen_optimizer = tf.keras.optimizers.Adam()
    client_disc_optimizer = tf.keras.optimizers.Adam()
    train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
        gan_loss_fns, server_gen_optimizer)
    train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
        gan_loss_fns, client_disc_optimizer)

    if with_dp:
        # NOTE(review): target_unclipped_quantile=3 looks out of the usual
        # [0, 1] quantile range -- presumably intentional for this test, but
        # worth confirming against the tensorflow_privacy API.
        dp_average_query = tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
            initial_l2_norm_clip=BEFORE_DP_L2_NORM_CLIP,
            noise_multiplier=0.3,
            target_unclipped_quantile=3,
            learning_rate=0.1,
            clipped_count_stddev=0.0,
            expected_num_records=10,
            denominator=10.0)
    else:
        dp_average_query = None

    return tff_gans.GanFnsAndTypes(
        generator_model_fn=one_dim_gan.create_generator,
        discriminator_model_fn=one_dim_gan.create_discriminator,
        # Single elements drawn only for TFF type inference.
        dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
        dummy_real_data=next(iter(one_dim_gan.create_real_data())),
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # `learning_rate` replaces the deprecated `lr` alias (removed in
        # modern Keras optimizers); the value is unchanged.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)
# Example #3
    def test_tff_training_loop(self, dp_average_query, checkpoint):
        """Runs the federated training loop and verifies example counters.

        Trains for 2 rounds (2 clients, 2 discriminator steps/client,
        3 generator steps/round) and checks the server-side counters. When
        `checkpoint` is True, trains one additional round with a fresh
        GanFnsAndTypes instance, expecting the loop to resume from the
        checkpoint written at round 2.
        """
        if checkpoint:
            root_checkpoint_dir = os.path.join(self.get_temp_dir(),
                                               'checkpoints')
        else:
            root_checkpoint_dir = None

        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        gan = tff_gans.GanFnsAndTypes(
            generator_model_fn=one_dim_gan.create_generator,
            discriminator_model_fn=one_dim_gan.create_discriminator,
            # Single elements drawn only for TFF type inference.
            dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
            dummy_real_data=next(iter(one_dim_gan.create_real_data())),
            train_generator_fn=train_generator_fn,
            train_discriminator_fn=train_discriminator_fn,
            # `learning_rate` replaces the deprecated `lr` alias (removed in
            # modern Keras optimizers); the value is unchanged.
            server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
                learning_rate=1.0),
            train_discriminator_dp_average_query=dp_average_query)

        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        client_disc_train_steps = 2
        server_gen_train_steps = 3

        # Window the infinite input streams into per-round / per-client
        # sub-datasets of the appropriate number of train steps.
        server_gen_inputs = iter(gen_inputs.window(server_gen_train_steps))
        client_gen_inputs = iter(gen_inputs.window(client_disc_train_steps))
        client_real_data = iter(real_data.window(client_disc_train_steps))

        def server_gen_inputs_fn(_):
            return next(server_gen_inputs)

        num_clients = 2

        def client_datasets_fn(_):
            return [(next(client_gen_inputs), next(client_real_data))
                    for _ in range(num_clients)]

        server_state, _ = training_loops.federated_training_loop(
            gan,
            server_gen_inputs_fn=server_gen_inputs_fn,
            client_datasets_fn=client_datasets_fn,
            total_rounds=2,
            rounds_per_checkpoint=1,
            root_checkpoint_dir=root_checkpoint_dir)

        # rounds * steps-per-round * batch size (* clients for the
        # discriminator, which trains on every client each round).
        self.assertDictEqual(
            server_state.counters, {
                'num_rounds':
                2,
                'num_generator_train_examples':
                2 * 3 * one_dim_gan.BATCH_SIZE,
                'num_discriminator_train_examples':
                (2 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
            })
        if checkpoint:
            # TODO(b/141112101): We shouldn't need to re-create the gan, should be
            # able to reuse the instance from above. See comment inside tff_gans.py.
            train_generator_fn, train_discriminator_fn = (
                _get_train_generator_and_discriminator_fns())
            gan = tff_gans.GanFnsAndTypes(
                generator_model_fn=one_dim_gan.create_generator,
                discriminator_model_fn=one_dim_gan.create_discriminator,
                dummy_gen_input=next(
                    iter(one_dim_gan.create_generator_inputs())),
                dummy_real_data=next(iter(one_dim_gan.create_real_data())),
                train_generator_fn=train_generator_fn,
                train_discriminator_fn=train_discriminator_fn,
                # `learning_rate` replaces the deprecated `lr` alias; also
                # un-breaks the formatter's mid-attribute line split.
                server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
                    learning_rate=1.0),
                train_discriminator_dp_average_query=dp_average_query)
            # Train one more round, which should resume from the checkpoint.
            server_state, _ = training_loops.federated_training_loop(
                gan,
                server_gen_inputs_fn=server_gen_inputs_fn,
                client_datasets_fn=client_datasets_fn,
                total_rounds=3,
                rounds_per_checkpoint=1,
                root_checkpoint_dir=root_checkpoint_dir)
            # Note: It would be better to return something from
            # federated_training_loop indicating the number of rounds trained in this
            # invocation, so we could verify the checkpoint was read.
            self.assertDictEqual(
                server_state.counters, {
                    'num_rounds':
                    3,
                    'num_generator_train_examples':
                    3 * 3 * one_dim_gan.BATCH_SIZE,
                    'num_discriminator_train_examples':
                    (3 * 2 * one_dim_gan.BATCH_SIZE * num_clients)
                })