Example #1
0
 def server_initial_state():
     """Builds the initial `ServerState` (variant without server optimizers).

     Returns:
       The result of `gan_training_tf_fns.server_initial_state`, built from
       freshly constructed generator/discriminator models and the initial DP
       averaging aggregation state (an empty tuple when DP averaging is off).
     """
     gen = gan.generator_model_fn()
     disc = gan.discriminator_model_fn()
     # No DP averaging configured means there is no aggregation state to track.
     if gan.dp_averaging_fn is None:
         dp_state = ()
     else:
         dp_state = gan.dp_averaging_fn.initialize()
     return gan_training_tf_fns.server_initial_state(gen, disc, dp_state)
 def server_initial_state(gen_lr=0.001, disc_lr=0.0002):
     """Builds the initial `ServerState`, including state-managed optimizers.

     Generalized: the previously hard-coded learning rates are now keyword
     parameters with the same values as defaults, so existing no-arg callers
     are unaffected.

     Args:
       gen_lr: Learning rate for the state generator optimizer. Defaults to
         0.001 (the previously hard-coded value).
       disc_lr: Learning rate for the state discriminator optimizer. Defaults
         to 0.0002 (the previously hard-coded value).

     Returns:
       The result of `gan_training_tf_fns.server_initial_state`, carrying the
       models, their optimizers, and the initial DP averaging state.
     """
     generator = gan.generator_model_fn()
     discriminator = gan.discriminator_model_fn()
     # Optimizer variables must be created before they can be captured in the
     # returned server state.
     gen_opt = gan.state_gen_optimizer_fn(gen_lr)
     gan_training_tf_fns.initialize_optimizer_vars(generator, gen_opt)
     disc_opt = gan.state_disc_optimizer_fn(disc_lr)
     gan_training_tf_fns.initialize_optimizer_vars(discriminator, disc_opt)
     # Empty tuple when DP averaging is disabled; otherwise the aggregator's
     # own initial state.
     dp_averaging_state = (() if gan.dp_averaging_fn is None else
                           gan.dp_averaging_fn.initialize())
     return gan_training_tf_fns.server_initial_state(
         generator, discriminator, gen_opt, disc_opt, dp_averaging_state)
    def test_client_and_server_computations(self):
        """Runs two rounds end-to-end; checks counters and DP averaging state."""
        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        # N.B. Re-using a single generator/discriminator pair with these
        # datasets is not meaningful from an ML perspective, but it is enough
        # to exercise the TF functions. For proper usage of these functions,
        # see training_loops.py.
        gen = one_dim_gan.create_generator()
        disc = one_dim_gan.create_discriminator()
        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        state = gan_training_tf_fns.server_initial_state(
            gen, disc, INIT_DP_AVERAGING_STATE)

        # server_initial_state() must carry the DP averaging aggregation state
        # through unchanged.
        self.assertEqual(state.dp_averaging_state, INIT_DP_AVERAGING_STATE)

        from_server = gan_training_tf_fns.FromServer(
            generator_weights=state.generator_weights,
            discriminator_weights=state.discriminator_weights)
        client_output = gan_training_tf_fns.client_computation(
            gen_inputs.take(3), real_data.take(3), from_server, gen, disc,
            train_discriminator_fn)

        disc_update_optimizer = tf.keras.optimizers.Adam()
        num_rounds = 2
        for _ in range(num_rounds):
            state = gan_training_tf_fns.server_computation(
                state, gen_inputs.take(3), client_output, gen, disc,
                disc_update_optimizer, train_generator_fn,
                NEW_DP_AVERAGING_STATE)

        # Each round trains on 3 batches for both networks.
        expected_examples = num_rounds * 3 * one_dim_gan.BATCH_SIZE
        self.assertDictEqual(
            self.evaluate(state.counters), {
                'num_rounds': num_rounds,
                'num_discriminator_train_examples': expected_examples,
                'num_generator_train_examples': expected_examples,
            })

        # server_computation() must propagate the updated DP averaging state.
        self.assertEqual(state.dp_averaging_state, NEW_DP_AVERAGING_STATE)
Example #4
0
def simple_training_loop(generator_model_fn,
                         discriminator_model_fn,
                         real_data_fn,
                         gen_inputs_fn,
                         train_generator_fn,
                         train_discriminator_fn,
                         total_rounds=30,
                         client_disc_train_steps=16,
                         server_gen_train_steps=8,
                         rounds_per_eval=10,
                         eval_hook=lambda *args: None):
  """Runs client_computation/server_computation rounds directly in TF.

  Primarily intended for testing; this is not a general-purpose training loop
  (for instance, the optimizers are hard-coded).

  Args:
    generator_model_fn: No-arg function returning the generator model.
    discriminator_model_fn: No-arg function returning the discriminator model.
    real_data_fn: No-arg function returning a dataset of real data batches.
    gen_inputs_fn: No-arg function returning a dataset of generator input
      batches.
    train_generator_fn: Function taking the two networks and generator input
      which trains the generator.
    train_discriminator_fn: Function taking the two networks, generator input,
      and real data which trains the discriminator.
    total_rounds: Number of rounds to train.
    client_disc_train_steps: Discriminator training batches per round.
    server_gen_train_steps: Generator training batches per round.
    rounds_per_eval: How often to invoke the `eval_hook` function.
    eval_hook: Optional function of (generator, discriminator, server_state,
      round_num) which performs evaluation.

  Returns:
    A tuple (final `ServerState`, train_time_in_seconds).
  """
  logging.info('Starting simple_training_loop')
  # N.B. `take(...)` inside the loop would replay the same examples every
  # round. `window` splits one Dataset into a sequence of Datasets, so each
  # round consumes fresh examples.
  client_input_windows = iter(gen_inputs_fn().window(client_disc_train_steps))
  client_data_windows = iter(real_data_fn().window(client_disc_train_steps))
  server_input_windows = iter(gen_inputs_fn().window(server_gen_train_steps))

  server_generator = generator_model_fn()
  server_discriminator = discriminator_model_fn()
  # A single shared generator/discriminator pair would probably work, but
  # separate client copies are more faithful to how this code is used in TFF.
  client_generator = generator_model_fn()
  client_discriminator = discriminator_model_fn()

  server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

  server_state = gan_training_tf_fns.server_initial_state(
      server_generator, server_discriminator)

  start_time = time.time()

  def do_eval(round_num):
    # Reads the current (rebound) server_state from the enclosing scope.
    eval_hook(server_generator, server_discriminator, server_state, round_num)
    elapsed_minutes = (time.time() - start_time) / 60
    print(
        'Total training time {:.2f} minutes for {} rounds '
        '({:.2f} rounds per minute)'.format(elapsed_minutes, round_num,
                                            round_num / elapsed_minutes),
        flush=True)

  logging.info('Starting training')
  for round_num in range(total_rounds):
    if round_num % rounds_per_eval == 0:
      do_eval(round_num)

    from_server = gan_training_tf_fns.FromServer(
        generator_weights=server_state.generator_weights,
        discriminator_weights=server_state.discriminator_weights)
    client_output = gan_training_tf_fns.client_computation(
        gen_inputs_ds=next(client_input_windows),
        real_data_ds=next(client_data_windows),
        from_server=from_server,
        generator=client_generator,
        discriminator=client_discriminator,
        train_discriminator_fn=train_discriminator_fn)

    server_state = gan_training_tf_fns.server_computation(
        server_state=server_state,
        gen_inputs_ds=next(server_input_windows),
        client_output=client_output,
        generator=server_generator,
        discriminator=server_discriminator,
        server_disc_update_optimizer=server_disc_update_optimizer,
        train_generator_fn=train_generator_fn)

  # Timing excludes the final evaluation, as in the original contract.
  train_time = time.time() - start_time
  do_eval(total_rounds)
  return server_state, train_time