Example #1
import os

import numpy as np

# Note: this example also relies on a project-local `training_loops` module
# (providing `federated_training_loop`), assumed to be imported elsewhere in
# the file.


def _train(gan, server_gen_inputs_dataset, client_gen_inputs_dataset,
           client_real_images_tff_data, client_disc_train_steps,
           server_gen_train_steps, clients_per_round, total_rounds,
           rounds_per_eval, eval_hook_fn, rounds_per_checkpoint, output_dir,
           exp_name):
  """Trains the federated GAN."""
  server_gen_inputs_iterator = iter(
      server_gen_inputs_dataset.window(server_gen_train_steps))
  client_gen_inputs_iterator = iter(
      client_gen_inputs_dataset.window(client_disc_train_steps))

  def server_gen_inputs_fn(round_num):
    del round_num
    return next(server_gen_inputs_iterator)

  def client_datasets_fn(round_num):
    """Forms clients_per_round number of datasets for a round of computation."""
    del round_num
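    # With a 'LOC'/'LOC' or 'CEN'/'CEN' generator/discriminator configuration,
    # only the first client is used; otherwise `clients_per_round` clients are
    # sampled without replacement.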
    if ((gan.gen_status == 'LOC' and gan.disc_status == 'LOC') or
        (gan.gen_status == 'CEN' and gan.disc_status == 'CEN')):
      client_ids = [client_real_images_tff_data.client_ids[0]]
    else:
      client_ids = np.random.choice(
          client_real_images_tff_data.client_ids,
          size=clients_per_round,
          replace=False)
    datasets = []
    for client_id in client_ids:
      real_images = (
          client_real_images_tff_data.create_tf_dataset_for_client(client_id)
          .repeat(client_disc_train_steps).take(client_disc_train_steps))
      datasets.append((next(client_gen_inputs_iterator), real_images))
    return datasets

  return training_loops.federated_training_loop(
      gan,
      server_gen_inputs_fn=server_gen_inputs_fn,
      client_datasets_fn=client_datasets_fn,
      total_rounds=total_rounds,
      rounds_per_eval=rounds_per_eval,
      eval_hook=eval_hook_fn,
      rounds_per_checkpoint=rounds_per_checkpoint,
      root_checkpoint_dir=os.path.join(output_dir,
                                       'checkpoints/{}'.format(exp_name)))
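
# The snippet below is not part of the example above; it is a minimal,
# self-contained illustration of the `Dataset.window` + `iter` pattern that
# the generator-input pipelines rely on: each `next()` call yields a
# sub-dataset containing one round's worth of batches. The dataset contents
# here are placeholders.
import tensorflow as tf

steps_per_round = 3
# An endless stream of batches, standing in for the generator-input dataset.
stream = tf.data.Dataset.range(1000).batch(4).repeat()
round_windows = iter(stream.window(steps_per_round))

one_round = next(round_windows)  # A dataset of `steps_per_round` batches.
for batch in one_round:
  print(batch.numpy())  # Three consecutive batches of four elements each.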

Example #2
    # Test method (likely from a parameterized TensorFlow test case); it
    # exercises training_loops.federated_training_loop with and without
    # checkpointing, relying on the tff_gans and one_dim_gan modules and the
    # _get_train_generator_and_discriminator_fns helper defined elsewhere.
    def test_tff_training_loop(self, dp_average_query, checkpoint):
        if checkpoint:
            root_checkpoint_dir = os.path.join(self.get_temp_dir(),
                                               'checkpoints')
        else:
            root_checkpoint_dir = None

        train_generator_fn, train_discriminator_fn = (
            _get_train_generator_and_discriminator_fns())

        gan = tff_gans.GanFnsAndTypes(
            generator_model_fn=one_dim_gan.create_generator,
            discriminator_model_fn=one_dim_gan.create_discriminator,
            dummy_gen_input=next(iter(one_dim_gan.create_generator_inputs())),
            dummy_real_data=next(iter(one_dim_gan.create_real_data())),
            train_generator_fn=train_generator_fn,
            train_discriminator_fn=train_discriminator_fn,
            server_disc_update_optimizer_fn=(
                lambda: tf.keras.optimizers.SGD(learning_rate=1.0)),
            train_discriminator_dp_average_query=dp_average_query)

        gen_inputs = one_dim_gan.create_generator_inputs()
        real_data = one_dim_gan.create_real_data()

        client_disc_train_steps = 2
        server_gen_train_steps = 3
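        # Two discriminator batches per client per round, and three generator
        # batches per round; these values feed directly into the counter
        # assertions below.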

        server_gen_inputs = iter(gen_inputs.window(server_gen_train_steps))
        client_gen_inputs = iter(gen_inputs.window(client_disc_train_steps))
        client_real_data = iter(real_data.window(client_disc_train_steps))

        def server_gen_inputs_fn(_):
            return next(server_gen_inputs)

        num_clients = 2

        def client_datasets_fn(_):
            return [(next(client_gen_inputs), next(client_real_data))
                    for _ in range(num_clients)]

        server_state, _ = training_loops.federated_training_loop(
            gan,
            server_gen_inputs_fn=server_gen_inputs_fn,
            client_datasets_fn=client_datasets_fn,
            total_rounds=2,
            rounds_per_checkpoint=1,
            root_checkpoint_dir=root_checkpoint_dir)

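        # After 2 rounds the generator has trained on 2 * 3 batches, and the
        # discriminator on 2 * 2 batches from each of the `num_clients`
        # clients.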
        self.assertDictEqual(
            server_state.counters, {
                'num_rounds': 2,
                'num_generator_train_examples': 2 * 3 * one_dim_gan.BATCH_SIZE,
                'num_discriminator_train_examples':
                    2 * 2 * one_dim_gan.BATCH_SIZE * num_clients,
            })
        if checkpoint:
            # TODO(b/141112101): We shouldn't need to re-create the gan; we
            # should be able to reuse the instance from above. See comment
            # inside tff_gans.py.
            train_generator_fn, train_discriminator_fn = (
                _get_train_generator_and_discriminator_fns())
            gan = tff_gans.GanFnsAndTypes(
                generator_model_fn=one_dim_gan.create_generator,
                discriminator_model_fn=one_dim_gan.create_discriminator,
                dummy_gen_input=next(
                    iter(one_dim_gan.create_generator_inputs())),
                dummy_real_data=next(iter(one_dim_gan.create_real_data())),
                train_generator_fn=train_generator_fn,
                train_discriminator_fn=train_discriminator_fn,
                server_disc_update_optimizer_fn=(
                    lambda: tf.keras.optimizers.SGD(learning_rate=1.0)),
                train_discriminator_dp_average_query=dp_average_query)
            # Train one more round, which should resume from the checkpoint.
            server_state, _ = training_loops.federated_training_loop(
                gan,
                server_gen_inputs_fn=server_gen_inputs_fn,
                client_datasets_fn=client_datasets_fn,
                total_rounds=3,
                rounds_per_checkpoint=1,
                root_checkpoint_dir=root_checkpoint_dir)
            # Note: It would be better to return something from
            # federated_training_loop indicating the number of rounds trained
            # in this invocation, so we could verify the checkpoint was read.
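            # After resuming from the checkpoint, the counters should cover 3
            # total rounds: 2 from the first invocation plus 1 new round.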
            self.assertDictEqual(
                server_state.counters, {
                    'num_rounds': 3,
                    'num_generator_train_examples':
                        3 * 3 * one_dim_gan.BATCH_SIZE,
                    'num_discriminator_train_examples':
                        3 * 2 * one_dim_gan.BATCH_SIZE * num_clients,
                })