Example #1
def client_computation(gen_inputs, real_data, from_server):
  """Returns the client_output."""
  return gan_training_tf_fns.client_computation(
      gen_inputs_ds=gen_inputs,
      real_data_ds=real_data,
      from_server=from_server,
      generator=gan.generator_model_fn(),
      discriminator=gan.discriminator_model_fn(),
      train_discriminator_fn=gan.train_discriminator_fn)
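A hedged usage sketch of the wrapper above: it is invoked with a slice of generator inputs, a slice of real data, and a FromServer message built from the current server state. This mirrors the call pattern in Examples #3 and #5; the gen_inputs, real_data, and server_state objects are assumed to exist and are purely illustrative:

client_output = client_computation(
    gen_inputs=gen_inputs.take(3),
    real_data=real_data.take(3),
    from_server=gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights))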
Example #2
def client_computation(gen_inputs, real_data, from_server,
                       control_input_gen, control_input_disc):
  """Returns the client_output."""
  return gan_training_tf_fns.client_computation(
      gen_inputs_ds=gen_inputs,
      real_data_ds=real_data,
      from_server=from_server,
      generator=gan.generator_model_fn(),
      discriminator=gan.discriminator_model_fn(),
      gen_optimizer=gen_optimizer_fn(),
      disc_optimizer=disc_optimizer_fn(),
      control_input_gen=control_input_gen,
      control_input_disc=control_input_disc,
      tau=tau)
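This variant additionally threads per-client optimizers, control inputs, and a scalar tau through to the underlying TF function. A minimal sketch of what the optimizer factories might look like, assuming standard Keras optimizers (the names gen_optimizer_fn and disc_optimizer_fn come from the snippet above; the optimizer classes and learning rates are illustrative assumptions):

def gen_optimizer_fn():
  # Hypothetical factory: any no-arg callable returning a Keras optimizer.
  return tf.keras.optimizers.Adam(learning_rate=1e-3)

def disc_optimizer_fn():
  # Hypothetical factory; Example #5 uses SGD for the server-side update.
  return tf.keras.optimizers.SGD(learning_rate=0.01)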
Example #3
  def test_client_and_server_computations(self):
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # N.B. The way we are using datasets and re-using the same
    # generator and discriminator doesn't really "make sense" from an ML
    # perspective, but it's sufficient for testing. For more proper usage of
    # these functions, see training_loops.py.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator)

    # The aggregation state (e.g., used for handling DP averaging) is
    # initialized to be empty. A user of the `server_initial_state` is expected
    # to take the output `ServerState` object and populate this field, most
    # likely via an instance of tff.templates.AggregationProcess.
    self.assertEmpty(server_state.aggregation_state)

    client_output = gan_training_tf_fns.client_computation(
        gen_inputs.take(3), real_data.take(3),
        gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights),
        generator, discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    for _ in range(2):  # Train for 2 rounds
      server_state = gan_training_tf_fns.server_computation(
          server_state, gen_inputs.take(3), client_output, generator,
          discriminator, server_disc_update_optimizer, train_generator_fn,
          NEW_DP_AVERAGING_STATE)

    counters = self.evaluate(server_state.counters)
    self.assertDictEqual(
        counters, {
            'num_rounds': 2,
            'num_discriminator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE,
            'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE
        })

    # DP averaging aggregation state updates properly in server_computation().
    self.assertEqual(server_state.aggregation_state, NEW_DP_AVERAGING_STATE)
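The constants NEW_DP_AVERAGING_STATE (and INIT_DP_AVERAGING_STATE in Example #4) are module-level test fixtures whose definitions are not shown here. A plausible placeholder, purely for illustration (any value the aggregation process understands would do):

INIT_DP_AVERAGING_STATE = 'init_dp_averaging_state'  # hypothetical placeholder
NEW_DP_AVERAGING_STATE = 'new_dp_averaging_state'  # hypothetical placeholder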
Example #4
  def test_client_and_server_computations(self):
    train_generator_fn, train_discriminator_fn = (
        _get_train_generator_and_discriminator_fns())

    # N.B. The way we are using datasets and re-using the same
    # generator and discriminator doesn't really "make sense" from an ML
    # perspective, but it's sufficient for testing. For more proper usage of
    # these functions, see training_loops.py.
    generator = one_dim_gan.create_generator()
    discriminator = one_dim_gan.create_discriminator()
    gen_inputs = one_dim_gan.create_generator_inputs()
    real_data = one_dim_gan.create_real_data()

    server_state = gan_training_tf_fns.server_initial_state(
        generator, discriminator, INIT_DP_AVERAGING_STATE)

    # DP averaging aggregation state is initialized properly in
    # server_initial_state().
    self.assertEqual(server_state.dp_averaging_state, INIT_DP_AVERAGING_STATE)

    client_output = gan_training_tf_fns.client_computation(
        gen_inputs.take(3), real_data.take(3),
        gan_training_tf_fns.FromServer(
            generator_weights=server_state.generator_weights,
            discriminator_weights=server_state.discriminator_weights),
        generator, discriminator, train_discriminator_fn)

    server_disc_update_optimizer = tf.keras.optimizers.Adam()
    for _ in range(2):  # Train for 2 rounds
      server_state = gan_training_tf_fns.server_computation(
          server_state, gen_inputs.take(3), client_output, generator,
          discriminator, server_disc_update_optimizer, train_generator_fn,
          NEW_DP_AVERAGING_STATE)

    counters = self.evaluate(server_state.counters)
    self.assertDictEqual(
        counters, {
            'num_rounds': 2,
            'num_discriminator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE,
            'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE
        })

    # DP averaging aggregation state updates properly in server_computation().
    self.assertEqual(server_state.dp_averaging_state, NEW_DP_AVERAGING_STATE)
Example #5
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
    """Trains in TF using client_computation and server_computation.

  This is not intended to be a general-purpose training loop (e.g., the
  optimizers are hard-coded), it is primarily intended for testing.

  Args:
    generator_model_fn: A no-arg function return the generator model.
    discriminator_model_fn: A no-arg function return the discriminator model.
    real_data_fn: A no-arg function returning a dataset of real data batches.
    gen_inputs_fn: A no-arg function returning a dataset of generator input
      batches.
    train_generator_fn: A function which takes the two networks and generator
      input and trains the generator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.
    total_rounds: Number of rounds to train.
    client_disc_train_steps: Number of discriminator training batches per round.
    server_gen_train_steps: Number of generator training batches per round.
    rounds_per_eval: How often to call the  `eval_hook` function.
    eval_hook: A function taking arguments (generator, discriminator,
      server_state, round_num) and performs evaluation. Optional.

  Returns:
    A tuple (final `ServerState`, train_time_in_seconds).
  """
    logging.info('Starting simple_training_loop')
    # N.B. We can't use real_data.take(...) in the loops below,
    # or we would get the same examples on every round. Using window
    # essentially breaks one Dataset into a sequence of Datasets,
    # which is exactly what we need here.
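    # As a purely illustrative example (not part of this module):
    # tf.data.Dataset.range(6).window(3) yields two sub-datasets, [0, 1, 2]
    # and [3, 4, 5], so each next(...) call below consumes a fresh window of
    # examples rather than replaying the head of the stream.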
    client_gen_inputs = iter(gen_inputs_fn().window(client_disc_train_steps))
    client_real_data = iter(real_data_fn().window(client_disc_train_steps))

    server_gen_inputs = iter(gen_inputs_fn().window(server_gen_train_steps))

    server_generator = generator_model_fn()
    server_discriminator = discriminator_model_fn()
    # We could probably use a single copy of the generator and discriminator, but
    # using separate copies is more faithful to how this code will be used in TFF.
    client_generator = generator_model_fn()
    client_discriminator = discriminator_model_fn()

    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

    server_state = gan_training_tf_fns.server_initial_state(
        server_generator, server_discriminator)

    start_time = time.time()

    def do_eval(round_num):
        eval_hook(server_generator, server_discriminator, server_state,
                  round_num)
        elapsed_minutes = (time.time() - start_time) / 60
        print('Total training time {:.2f} minutes for {} rounds '
              '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                                  round_num / elapsed_minutes),
              flush=True)

    logging.info('Starting training')
    for round_num in range(total_rounds):
        if round_num % rounds_per_eval == 0:
            do_eval(round_num)

        client_output = gan_training_tf_fns.client_computation(
            gen_inputs_ds=next(client_gen_inputs),
            real_data_ds=next(client_real_data),
            from_server=gan_training_tf_fns.FromServer(
                generator_weights=server_state.generator_weights,
                discriminator_weights=server_state.discriminator_weights),
            generator=client_generator,
            discriminator=client_discriminator,
            train_discriminator_fn=train_discriminator_fn)

        server_state = gan_training_tf_fns.server_computation(
            server_state=server_state,
            gen_inputs_ds=next(server_gen_inputs),
            client_output=client_output,
            generator=server_generator,
            discriminator=server_discriminator,
            server_disc_update_optimizer=server_disc_update_optimizer,
            train_generator_fn=train_generator_fn)

    train_time = time.time() - start_time
    do_eval(total_rounds)
    return server_state, train_time
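For completeness, a minimal sketch of driving this loop with the one_dim_gan test helpers and the train fns from Example #3 (the helper names are taken from the earlier examples; the round count is an illustrative assumption, not a recommended setting):

train_generator_fn, train_discriminator_fn = (
    _get_train_generator_and_discriminator_fns())
server_state, train_time = simple_training_loop(
    generator_model_fn=one_dim_gan.create_generator,
    discriminator_model_fn=one_dim_gan.create_discriminator,
    real_data_fn=one_dim_gan.create_real_data,
    gen_inputs_fn=one_dim_gan.create_generator_inputs,
    train_generator_fn=train_generator_fn,
    train_discriminator_fn=train_discriminator_fn,
    total_rounds=2)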