Ejemplo n.º 1
0
def _get_gan(with_dp=False):
    """Build a `tff_gans.GanFnsAndTypes` for the one-dimensional test GAN.

    Args:
      with_dp: If True, discriminator updates are averaged with a
        `tensorflow_privacy.QuantileAdaptiveClipAverageQuery`; otherwise no
        DP averaging query is attached (`None`).

    Returns:
      A `tff_gans.GanFnsAndTypes` combining the one-dim GAN model factories,
      Wasserstein-loss generator/discriminator training functions (each with
      its own Adam optimizer), one dummy generator input, one dummy real-data
      element, and an SGD(learning_rate=1.0) server discriminator optimizer.
    """
    gan_loss_fns = gan_losses.get_gan_loss_fns('wasserstein')
    server_gen_optimizer = tf.keras.optimizers.Adam()
    client_disc_optimizer = tf.keras.optimizers.Adam()
    train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
        gan_loss_fns, server_gen_optimizer)
    train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
        gan_loss_fns, client_disc_optimizer)

    if with_dp:
        # NOTE(review): target_unclipped_quantile=3 is outside [0, 1]; this is
        # presumably deliberate so the adaptive clip norm moves away from
        # BEFORE_DP_L2_NORM_CLIP every round. Confirm before changing it.
        dp_average_query = tensorflow_privacy.QuantileAdaptiveClipAverageQuery(
            initial_l2_norm_clip=BEFORE_DP_L2_NORM_CLIP,
            noise_multiplier=0.3,
            target_unclipped_quantile=3,
            learning_rate=0.1,
            clipped_count_stddev=0.0,
            expected_num_records=10,
            denominator=10.0)
    else:
        dp_average_query = None

    return tff_gans.GanFnsAndTypes(
        generator_model_fn=one_dim_gan.create_generator,
        discriminator_model_fn=one_dim_gan.create_discriminator,
        # A single element from each dataset is enough to fix input types.
        dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
        dummy_real_data=next(iter(one_dim_gan.create_real_data())),
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # `learning_rate` is the supported keyword; `lr` is a deprecated
        # alias that newer Keras releases reject.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)
Ejemplo n.º 2
0
def _get_gan(gen_model_fn, disc_model_fn, gan_loss_fns, gen_optimizer,
             disc_optimizer, server_gen_inputs_dataset,
             client_real_images_tff_data, use_dp, dp_l2_norm_clip,
             dp_noise_multiplier, clients_per_round):
    """Construct instance of tff_gans.GanFnsAndTypes class.

    Args:
      gen_model_fn: Factory for the generator model.
      disc_model_fn: Factory for the discriminator model.
      gan_loss_fns: Loss functions passed to both training-fn constructors.
      gen_optimizer: Optimizer used by the generator training function.
      disc_optimizer: Optimizer used by the discriminator training function.
      server_gen_inputs_dataset: Iterable whose first element serves as the
        dummy generator input.
      client_real_images_tff_data: Client dataset collection; the first
        element of the first client's dataset serves as dummy real data.
      use_dp: If True, attach a `tensorflow_privacy.GaussianAverageQuery`
        for discriminator updates; otherwise no DP query (`None`).
      dp_l2_norm_clip: L2 clipping norm for the DP query.
      dp_noise_multiplier: Noise stddev is `dp_l2_norm_clip` times this.
      clients_per_round: Denominator used by the DP average query.

    Returns:
      A `tff_gans.GanFnsAndTypes` wiring the above together, with an
      SGD(learning_rate=1.0) server discriminator-update optimizer.
    """
    # One element of each dataset is enough to determine input types.
    dummy_gen_input = next(iter(server_gen_inputs_dataset))
    dummy_real_data = next(
        iter(
            client_real_images_tff_data.create_tf_dataset_for_client(
                client_real_images_tff_data.client_ids[0])))

    train_generator_fn = gan_training_tf_fns.create_train_generator_fn(
        gan_loss_fns, gen_optimizer)
    train_discriminator_fn = gan_training_tf_fns.create_train_discriminator_fn(
        gan_loss_fns, disc_optimizer)

    dp_average_query = None
    if use_dp:
        dp_average_query = tensorflow_privacy.GaussianAverageQuery(
            l2_norm_clip=dp_l2_norm_clip,
            sum_stddev=dp_l2_norm_clip * dp_noise_multiplier,
            denominator=clients_per_round)

    return tff_gans.GanFnsAndTypes(
        generator_model_fn=gen_model_fn,
        discriminator_model_fn=disc_model_fn,
        dummy_gen_input=dummy_gen_input,
        dummy_real_data=dummy_real_data,
        train_generator_fn=train_generator_fn,
        train_discriminator_fn=train_discriminator_fn,
        # `learning_rate` is the supported keyword; `lr` is a deprecated
        # alias that newer Keras releases reject.
        server_disc_update_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            learning_rate=1.0),
        train_discriminator_dp_average_query=dp_average_query)
def _get_train_generator_and_discriminator_fns():
    """Return `(train_generator_fn, train_discriminator_fn)` for Wasserstein loss.

    Each training function is given its own freshly constructed Adam
    optimizer.
    """
    wasserstein_losses = gan_losses.get_gan_loss_fns('wasserstein')

    gen_optimizer = tf.keras.optimizers.Adam()
    gen_step = gan_training_tf_fns.create_train_generator_fn(
        wasserstein_losses, gen_optimizer)

    disc_optimizer = tf.keras.optimizers.Adam()
    disc_step = gan_training_tf_fns.create_train_discriminator_fn(
        wasserstein_losses, disc_optimizer)

    return gen_step, disc_step
Ejemplo n.º 4
0
 def test_create_train_generator_fn(self):
     """The generator train fn must take (generator, discriminator, inputs)."""
     expected_args = ['generator', 'discriminator', 'generator_inputs']
     adam = tf.keras.optimizers.Adam()
     generator_step = gan_training_tf_fns.create_train_generator_fn(
         GAN_LOSS_FNS, adam)
     observed_args = generator_step.function_spec.fullargspec.args
     self.assertListEqual(expected_args, observed_args)
Ejemplo n.º 5
0
def _get_train_generator_and_discriminator_fns():
    """Return `(train_generator_fn, train_discriminator_fn)` built on GAN_LOSS_FNS.

    Each training function receives its own new Adam optimizer.
    """
    factories = (
        gan_training_tf_fns.create_train_generator_fn,
        gan_training_tf_fns.create_train_discriminator_fn,
    )
    # The generator expression is consumed lazily by the unpacking, so each
    # Adam optimizer is created immediately before its own factory call.
    gen_step, disc_step = (
        build(GAN_LOSS_FNS, tf.keras.optimizers.Adam()) for build in factories)
    return gen_step, disc_step