Example #1
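This `main` is the flag-driven entry point of a federated GAN training script built on TensorFlow Federated. It references module-level context defined elsewhere in the script; the sketch below shows the assumed imports and constants (the project-local aliases `ecm`, `fedu`, and `gan_losses`, the underscore-prefixed helpers, and the constant values are assumptions, not the actual package layout):

import collections

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_federated as tff

# Project-local modules (aliases taken from their usage below; the import
# paths are placeholders):
# from <project> import emnist_classifier_model as ecm
# from <project> import federated_data_utils as fedu
# from <project> import gan_losses

FLAGS = flags.FLAGS
hparam_flags = []             # names of flags to log; filled where flags are defined
CLIENT_TRAIN_BATCH_SIZE = 32  # assumed value
EVAL_BATCH_SIZE = 500         # assumed value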
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.set_verbosity(logging.INFO)

    # Flags.
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    # Replace unset flags with the string 'None' so they print/serialize
    # cleanly.
    for k, v in hparam_dict.items():
        if v is None:
            hparam_dict[k] = 'None'
    for k, v in hparam_dict.items():
        print('{} : {}'.format(k, v))

    tff.backends.native.set_local_execution_context(
        default_num_clients=FLAGS.num_clients_per_round)

    # Trained classifier model.
    classifier_model = ecm.get_trained_emnist_classifier_model()

    # GAN Models.
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

    # Training datasets.
    server_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
    client_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=CLIENT_TRAIN_BATCH_SIZE, noise_dim=FLAGS.noise_dim)

    if FLAGS.filtering == 'by_user':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_user_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                accuracy_threshold=FLAGS.accuracy_threshold,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    elif FLAGS.filtering == 'by_example':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_example_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                min_num_examples=FLAGS.min_num_examples,
                example_class_selection=FLAGS.example_class_selection,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    else:
        client_real_images_train_tff_data = (
            fedu.get_unfiltered_client_data_for_training(
                batch_size=CLIENT_TRAIN_BATCH_SIZE))

    print('There are %d unique clients that will be used for GAN training.' %
          len(client_real_images_train_tff_data.client_ids))

    # Training: GAN Losses and Optimizers.
    gan_loss_fns = gan_losses.WassersteinGanLossFns(
        grad_penalty_lambda=FLAGS.wass_gp_lambda)
    disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0005)
    gen_optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)

    # Eval datasets.
    gen_inputs_eval_dataset = _create_gen_inputs_dataset(
        batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
    real_images_eval_dataset = _create_real_images_dataset_for_eval()

    # Eval hook.
    path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                      FLAGS.exp_name)
    logging.info('path_to_output_images is %s', path_to_output_images)
    eval_hook_fn = _get_emnist_eval_hook_fn(
        FLAGS.exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
        gen_inputs_eval_dataset, real_images_eval_dataset,
        FLAGS.num_rounds_per_save_images, path_to_output_images,
        classifier_model)

    # Form the GAN.
    gan = _get_gan(gen_model_fn,
                   disc_model_fn,
                   gan_loss_fns,
                   gen_optimizer,
                   disc_optimizer,
                   server_gen_inputs_dataset,
                   client_real_images_train_tff_data,
                   use_dp=FLAGS.use_dp,
                   dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
                   dp_noise_multiplier=FLAGS.dp_noise_multiplier,
                   clients_per_round=FLAGS.num_clients_per_round)

    # Training.
    _, tff_time = _train(gan,
                         server_gen_inputs_dataset,
                         client_gen_inputs_dataset,
                         client_real_images_train_tff_data,
                         FLAGS.num_client_disc_train_steps,
                         FLAGS.num_server_gen_train_steps,
                         FLAGS.num_clients_per_round,
                         FLAGS.num_rounds,
                         FLAGS.num_rounds_per_eval,
                         eval_hook_fn,
                         FLAGS.num_rounds_per_checkpoint,
                         output_dir=FLAGS.root_output_dir,
                         exp_name=FLAGS.exp_name)
    logging.info('Total training time was %4.3f seconds.', tff_time)

    print('\nTRAINING COMPLETE.')
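With absl, a script exposing this `main` would typically hand control to it via `app.run`, which parses the flags before calling `main`:

if __name__ == '__main__':
    app.run(main)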
Example #2
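This is the per-client TFF computation: it copies the broadcast weights into local models, runs Wasserstein GAN training steps with a proximal penalty `tau` that pulls weights toward the server's meta weights, and returns per-example-averaged weight deltas. It relies on the project-local types `FromServer` and `ClientOutput` and on `tensor_utils.zero_all_if_any_non_finite`. A minimal sketch of the two containers follows, with field names inferred from their usage below (the `attrs`-based definition is an assumption):

import attr

@attr.s(eq=False, frozen=True)
class FromServer(object):
    # Payload broadcast from server to clients.
    generator_weights = attr.ib()
    discriminator_weights = attr.ib()
    meta_gen = attr.ib()   # meta weights for the generator proximal term
    meta_disc = attr.ib()  # meta weights for the discriminator proximal term

@attr.s(eq=False, frozen=True)
class ClientOutput(object):
    # Payload reported from clients back to the server.
    discriminator_weights_delta = attr.ib()
    generator_weights_delta = attr.ib()
    update_weight = attr.ib()
    counters = attr.ib()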
def client_control(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        disc_optimizer: tf.keras.optimizers.Optimizer,
        gen_optimizer: tf.keras.optimizers.Optimizer,
        zero_disc: tf.keras.Model,
        zero_gen: tf.keras.Model,
        tau: float) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.

  Returns:
    A `ClientOutput` object.
  """
    # Copy the broadcast server weights into the local generator/discriminator
    # and into the zero_gen/zero_disc accumulator copies.
    tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
                          from_server.discriminator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), zero_gen.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), zero_disc.weights,
                          from_server.discriminator_weights)
    num_examples = tf.constant(0)
    meta_gen = from_server.meta_gen
    meta_disc = from_server.meta_disc
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))
    loss_fns = gan_losses.WassersteinGanLossFns()
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        # The generator/discriminator weights are intentionally not reset
        # inside the loop: gradients are taken w.r.t. them while the updates
        # are applied to the zero_gen/zero_disc copies below.
        with tf.GradientTape() as tape_gen:
            gen_loss = loss_fns.generator_loss(generator, discriminator,
                                               gen_inputs)
            for i in range(len(generator.weights)):
                gen_loss += tau * tf.nn.l2_loss(generator.weights[i] -
                                                meta_gen[i])
        with tf.GradientTape() as tape_disc:
            disc_loss = loss_fns.discriminator_loss(generator, discriminator,
                                                    gen_inputs, real_data)
            for i in range(len(discriminator.weights)):
                disc_loss += tau * tf.nn.l2_loss(discriminator.weights[i] -
                                                 meta_disc[i])
        # Gradients are computed w.r.t. the (fixed) client models, but the
        # optimizer updates are applied to the zero_disc/zero_gen copies,
        # which accumulate the local training steps.
        disc_grads = tape_disc.gradient(disc_loss, discriminator.weights)
        gen_grads = tape_gen.gradient(gen_loss, generator.weights)
        disc_grads_and_vars = zip(disc_grads, zero_disc.weights)
        gen_grads_and_vars = zip(gen_grads, zero_gen.weights)
        disc_optimizer.apply_gradients(disc_grads_and_vars)
        gen_optimizer.apply_gradients(gen_grads_and_vars)

        num_examples += min_batch_size
    num_examples_float = tf.cast(num_examples, tf.float32)
    # Per-example average of the accumulated weight deltas relative to the
    # weights broadcast by the server.
    disc_delta = tf.nest.map_structure(
        lambda a, b: (a - b) / num_examples_float, zero_disc.weights,
        from_server.discriminator_weights)
    gen_delta = tf.nest.map_structure(
        lambda a, b: (a - b) / num_examples_float, zero_gen.weights,
        from_server.generator_weights)

    disc_delta, disc_has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(disc_delta))
    gen_delta, gen_has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(gen_delta))
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = num_examples_float
    update_weight_disc = tf.cond(
        tf.equal(disc_has_non_finite_delta, 0),
        lambda: update_weight,
        lambda: tf.constant(0.0))
    update_weight_gen = tf.cond(
        tf.equal(gen_has_non_finite_delta, 0),
        lambda: update_weight,
        lambda: tf.constant(0.0))
    update_weight = tf.math.minimum(update_weight_disc, update_weight_gen)
    return ClientOutput(
        discriminator_weights_delta=disc_delta,
        generator_weights_delta=gen_delta,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})
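The `tau * tf.nn.l2_loss(...)` terms in both loss tapes implement a FedProx-style proximal penalty: each client is penalized for drifting from the server's meta weights. A self-contained illustration of the penalty (the helper name is hypothetical):

import tensorflow as tf

def proximal_penalty(weights, meta_weights, tau):
    # tau * sum_i 0.5 * ||w_i - m_i||^2, matching the per-tensor
    # tf.nn.l2_loss terms added to gen_loss and disc_loss above.
    return tau * tf.add_n(
        [tf.nn.l2_loss(w - m) for w, m in zip(weights, meta_weights)])

w = [tf.constant([1.0, 2.0])]
m = [tf.constant([0.5, 2.5])]
print(proximal_penalty(w, m, tau=0.1).numpy())  # 0.1 * (0.25/2 + 0.25/2) = 0.025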