Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    logging.set_verbosity(logging.INFO)

    # Flags.
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])
    for k in hparam_dict.keys():
        if hparam_dict[k] is None:
            hparam_dict[k] = 'None'
    for k, v in hparam_dict.items():
        print('{} : {} '.format(k, v))

    tff.framework.set_default_executor(
        tff.framework.local_executor_factory(
            num_clients=FLAGS.num_clients_per_round))

    # Trained classifier model.
    classifier_model = ecm.get_trained_emnist_classifier_model()

    # GAN Models.
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

    # Training datasets.
    server_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
    client_gen_inputs_dataset = _create_gen_inputs_dataset(
        batch_size=CLIENT_TRAIN_BATCH_SIZE, noise_dim=FLAGS.noise_dim)

    if FLAGS.filtering == 'by_user':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_user_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                accuracy_threshold=FLAGS.accuracy_threshold,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    elif FLAGS.filtering == 'by_example':
        client_real_images_train_tff_data = (
            fedu.get_filtered_by_example_client_data_for_training(
                invert_imagery_probability=FLAGS.invert_imagery_probability,
                min_num_examples=FLAGS.min_num_examples,
                example_class_selection=FLAGS.example_class_selection,
                batch_size=CLIENT_TRAIN_BATCH_SIZE))
    else:
        client_real_images_train_tff_data = (
            fedu.get_unfiltered_client_data_for_training(
                batch_size=CLIENT_TRAIN_BATCH_SIZE))

    print('There are %d unique clients that will be used for GAN training.' %
          len(client_real_images_train_tff_data.client_ids))

    # Training: GAN Losses and Optimizers.
    gan_loss_fns = gan_losses.WassersteinGanLossFns(
        grad_penalty_lambda=FLAGS.wass_gp_lambda)
    disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0005)
    gen_optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)

    # Eval datasets.
    gen_inputs_eval_dataset = _create_gen_inputs_dataset(
        batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
    real_images_eval_dataset = _create_real_images_dataset_for_eval()

    # Eval hook.
    path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                      FLAGS.exp_name)
    logging.info('path_to_output_images is %s', path_to_output_images)
    eval_hook_fn = _get_emnist_eval_hook_fn(
        FLAGS.exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
        gen_inputs_eval_dataset, real_images_eval_dataset,
        FLAGS.num_rounds_per_save_images, path_to_output_images,
        classifier_model)

    # Form the GAN.
    gan = _get_gan(gen_model_fn,
                   disc_model_fn,
                   gan_loss_fns,
                   gen_optimizer,
                   disc_optimizer,
                   server_gen_inputs_dataset,
                   client_real_images_train_tff_data,
                   use_dp=FLAGS.use_dp,
                   dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
                   dp_noise_multiplier=FLAGS.dp_noise_multiplier,
                   clients_per_round=FLAGS.num_clients_per_round)

    # Training.
    _, tff_time = _train(gan,
                         server_gen_inputs_dataset,
                         client_gen_inputs_dataset,
                         client_real_images_train_tff_data,
                         FLAGS.num_client_disc_train_steps,
                         FLAGS.num_server_gen_train_steps,
                         FLAGS.num_clients_per_round,
                         FLAGS.num_rounds,
                         FLAGS.num_rounds_per_eval,
                         eval_hook_fn,
                         FLAGS.num_rounds_per_checkpoint,
                         output_dir=FLAGS.root_output_dir,
                         exp_name=FLAGS.exp_name)
    logging.info('Total training time was %4.3f seconds.', tff_time)

    print('\nTRAINING COMPLETE.')
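For context, this `main` relies on absl scaffolding that the excerpt does not include: flag definitions and an `app.run` entry point. A minimal sketch, assuming standard absl APIs; the flag names are taken from the FLAGS references above and the defaults are placeholders.

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS
flags.DEFINE_integer('num_clients_per_round', 10, 'Clients sampled per round.')
flags.DEFINE_integer('noise_dim', 128, 'Dimension of the generator noise input.')
flags.DEFINE_string('root_output_dir', '/tmp/fedgan', 'Base directory for outputs.')

if __name__ == '__main__':
    app.run(main)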
def server_computation(
    # Tensor/Dataset arguments that will be supplied by TFF:
    server_state: ServerState,
    gen_inputs_ds: tf.data.Dataset,
    client_output: ClientOutput,
    # Python arguments to be bound at TFF computation construction time:
    generator: tf.keras.Model,
    discriminator: tf.keras.Model,
    state_gen_optimizer: tf.keras.optimizers.Optimizer,
    state_disc_optimizer: tf.keras.optimizers.Optimizer,
    # Not an argument bound at TFF computation construction time, but placed
    # last so that it can be defaulted to empty tuple (for non-DP use cases).
    new_dp_averaging_state=()
) -> ServerState:
    """The computation to run on the server, training the generator.

  Args:
    server_state: The initial `ServerState` for the round.
    gen_inputs_ds: An infinite `tf.data.Dataset` of inputs to the `generator`.
    client_output: The (possibly aggregated) `ClientOutput`.
    generator:  The generator.
    discriminator: The discriminator.
    state_gen_optimizer: The stateful optimizer used to train the generator on
      the server; its weights are carried in `ServerState`.
    state_disc_optimizer: The stateful optimizer for the discriminator; its
      weights are tracked in `ServerState` and updated from the client deltas.
    new_dp_averaging_state: The updated state of the DP averaging aggregator.

  Returns:
    An updated `ServerState` object.
  """
    # A tf.function can't modify the structure of its input arguments,
    # so we make a semi-shallow copy:
    server_state = attr.evolve(server_state,
                               counters=dict(server_state.counters))

    tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                          server_state.state_gen_optimizer_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          server_state.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          server_state.discriminator_weights)
    server_gen_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1)

    delta = client_output.discriminator_weights_delta
    tf.nest.assert_same_structure(delta, discriminator.trainable_weights)
    grads_and_vars = tf.nest.map_structure(lambda x, v: (-1.0 * x, v), delta,
                                           discriminator.trainable_weights)
    server_disc_update_optimizer.apply_gradients(grads_and_vars,
                                                 name='server_update_disc')

    for k, v in client_output.counters.items():
        server_state.counters[k] += v

    # Update the state of the DP averaging aggregator.
    server_state.dp_averaging_state = new_dp_averaging_state

    gen_examples_this_round = tf.constant(0)

    loss_fns = gan_losses.WassersteinGanLossFns()
    for gen_inputs in gen_inputs_ds:  # Compiled by autograph.
        with tf.GradientTape() as tape2:
            loss2 = loss_fns.generator_loss(generator, discriminator,
                                            gen_inputs)
        grads2 = tape2.gradient(loss2, generator.trainable_variables)
        grads_and_vars2 = zip(grads2, generator.trainable_variables)
        state_gen_optimizer.apply_gradients(grads_and_vars2)
        gen_examples_this_round += tf.shape(gen_inputs)[0]
    # update discriminator optimizer
    delta_opt_D = client_output.state_disc_opt_delta
    updated_opt_D = tf.nest.map_structure(
        lambda a, b: a + b, server_state.state_disc_optimizer_weights,
        delta_opt_D)

    G_change = tf.nest.map_structure(tf.subtract, generator.trainable_weights,
                                     server_state.generator_weights.trainable)
    D_change = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        server_state.discriminator_weights.trainable)
    server_state.state_gen_optimizer_weights = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), state_gen_optimizer.variables())
    server_state.state_disc_optimizer_weights = updated_opt_D
    server_state.counters[
        'num_generator_train_examples'] += gen_examples_this_round
    server_state.generator_weights = _weights(generator)
    server_state.discriminator_weights = _weights(discriminator)
    server_state.counters['num_rounds'] += 1
    server_state.generator_diff = G_change
    server_state.discriminator_diff = D_change
    return server_state
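The `tf.nest.map_structure(lambda a, b: a.assign(b), ...)` idiom above is how both the server and client computations push weights from state objects into the Keras models. A self-contained sketch of that pattern with two throwaway models (the layer sizes here are arbitrary):

import tensorflow as tf

def make_model():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])

source = make_model()
target = make_model()

# Assign every variable of `source` into the matching variable of `target`;
# the two weight lists must have identical structure for map_structure to pair them.
tf.nest.map_structure(lambda a, b: a.assign(b), target.weights, source.weights)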
Example #3
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.set_verbosity(logging.INFO)

  # Flags.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  for k in hparam_dict.keys():
    if hparam_dict[k] is None:
      hparam_dict[k] = 'None'
  for k, v in hparam_dict.items():
    print('{} : {} '.format(k, v))

  
  tff.backends.native.set_local_execution_context(
      num_clients=FLAGS.num_clients_per_round)
  
  exp_name = 'client_steps={0},clients_per_round={1},status={2},model_name={3},client_batch={4},optimizer={5},lr={6},lr_factor={7}'.format(
    FLAGS.num_client_disc_train_steps, FLAGS.num_clients_per_round, FLAGS.status, FLAGS.model,
    FLAGS.client_batch_size, FLAGS.optimizer, FLAGS.lr,FLAGS.lr_factor
  )
  cache_dir = os.path.join(FLAGS.root_output_dir, exp_name)
  cache_subdir = os.path.join(cache_dir, 'datasets')
  if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)
  if not os.path.exists(cache_subdir):
    os.makedirs(cache_subdir)
  
  if not os.path.exists(os.path.join(cache_subdir, 'fed_emnist_digitsonly_train.h5')):
    copyfile('/home/houc/.keras/datasets/fed_emnist.tar.bz2', os.path.join(cache_subdir, 'fed_emnist.tar.bz2'))
    copyfile('/home/houc/.keras/datasets/fed_emnist_test.h5', os.path.join(cache_subdir, 'fed_emnist_test.h5'))
    copyfile('/home/houc/.keras/datasets/fed_emnist_train.h5', os.path.join(cache_subdir, 'fed_emnist_train.h5'))
    copyfile('/home/houc/.keras/datasets/fed_emnist_digitsonly.tar.bz2', os.path.join(cache_subdir, 'fed_emnist_digitsonly.tar.bz2'))
    copyfile('/home/houc/.keras/datasets/fed_emnist_digitsonly_test.h5', os.path.join(cache_subdir, 'fed_emnist_digitsonly_test.h5'))
    copyfile('/home/houc/.keras/datasets/fed_emnist_digitsonly_train.h5', os.path.join(cache_subdir, 'fed_emnist_digitsonly_train.h5'))
  
  client_batch_size = FLAGS.client_batch_size
  # Trained classifier model.
  classifier_model = ecm.get_trained_emnist_classifier_model()

  # GAN Models.
  if FLAGS.model == 'spectral':
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim, True)
  else:
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

  # Training datasets.
  server_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
  client_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=client_batch_size, noise_dim=FLAGS.noise_dim)
  '''
  if FLAGS.filtering == 'by_user':
    client_real_images_train_tff_data = (
        fedu.get_filtered_by_user_client_data_for_training(
            invert_imagery_probability=FLAGS.invert_imagery_probability,
            accuracy_threshold=FLAGS.accuracy_threshold,
            batch_size=CLIENT_TRAIN_BATCH_SIZE))
  elif FLAGS.filtering == 'by_example':
    client_real_images_train_tff_data = (
        fedu.get_filtered_by_example_client_data_for_training(
            invert_imagery_probability=FLAGS.invert_imagery_probability,
            min_num_examples=FLAGS.min_num_examples,
            example_class_selection=FLAGS.example_class_selection,
            batch_size=CLIENT_TRAIN_BATCH_SIZE))
  else:
    client_real_images_train_tff_data = (
        fedu.get_unfiltered_client_data_for_training(
            batch_size=CLIENT_TRAIN_BATCH_SIZE))
  '''
  if FLAGS.status == 'CEN_CEN':
    #cache_dir = None
    
    if FLAGS.model.split('_')[0].lower() == 'mnist':
      (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
      def preprocess_images(images):
        images = np.float32(images.reshape((images.shape[0], 28, 28, 1))/255)
        return images
      train_images = preprocess_images(train_images)
      test_images = preprocess_images(test_images)
      train_size = 60000
      test_size = 10000
      central_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                  .shuffle(train_size, reshuffle_each_iteration=True).batch(client_batch_size))
    else:
      #cache_dir = None
      central_dataset = _create_real_images_dataset_for_central(client_batch_size, cache_dir)

    print('Dataset done', flush=True)
    def create_tf_dataset_for_client(client_id):
      return central_dataset
    client_real_images_train_tff_data = tff.simulation.ClientData.from_clients_and_fn(
        ["1"], create_tf_dataset_for_client)
  else:
    client_real_images_train_tff_data = (
        fedu.get_unfiltered_client_data_for_training(
            batch_size=client_batch_size, cache_dir=cache_dir))
  print('There are %d unique clients that will be used for GAN training.' %
        len(client_real_images_train_tff_data.client_ids))

  # Training: GAN Losses and Optimizers.
  gan_loss_fns = gan_losses.WassersteinGanLossFns(
      grad_penalty_lambda=FLAGS.wass_gp_lambda)
  '''
  if FLAGS.status == 'CEN_CEN' or FLAGS.status == 'LOC_LOC':
    #disc_optimizer = tf.keras.optimizers.SGD(lr=0.0004)
    #gen_optimizer = tf.keras.optimizers.SGD(lr=0.0001)
    gen_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
      [1000], [0.001, 0.0005]
    )
    disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0004)
    gen_optimizer = tf.keras.optimizers.SGD(learning_rate=gen_sched)
  else:
    disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0002)
    gen_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
  '''
  disc_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
  gen_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
  # Eval datasets.
  gen_inputs_eval_dataset = _create_gen_inputs_dataset(
      batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
  if FLAGS.model.split('_')[0].lower() == 'mnist':
    real_images_eval_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                      .shuffle(test_size, reshuffle_each_iteration=True).batch(EVAL_BATCH_SIZE))
  else:
    real_images_eval_dataset = _create_real_images_dataset_for_eval(cache_dir)

  # Eval hook.
  path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                    exp_name)
  logging.info('path_to_output_images is %s', path_to_output_images)
  #num_rounds_per_save_images = max(int(FLAGS.num_rounds/100),1)
  eval_hook_fn = _get_emnist_eval_hook_fn(
      exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
      gen_inputs_eval_dataset, real_images_eval_dataset,
      FLAGS.num_rounds_per_save_images, path_to_output_images, classifier_model)

  # Form the GAN
  '''
  gan = _get_gan(
      gen_model_fn,
      disc_model_fn,
      gan_loss_fns,
      gen_optimizer,
      disc_optimizer,
      server_gen_inputs_dataset,
      client_real_images_train_tff_data,
      use_dp=FLAGS.use_dp,
      dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
      dp_noise_multiplier=FLAGS.dp_noise_multiplier,
      clients_per_round=FLAGS.num_clients_per_round)
  '''
  statussplit = FLAGS.status.split('_')
  gan = _get_gan(
      gen_model_fn,
      disc_model_fn,
      gan_loss_fns,
      gen_optimizer,
      disc_optimizer,
      server_gen_inputs_dataset,
      client_real_images_train_tff_data,
      use_dp=FLAGS.use_dp,
      dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
      dp_noise_multiplier=FLAGS.dp_noise_multiplier,
      clients_per_round=FLAGS.num_clients_per_round,
      gen_status=statussplit[0],
      disc_status=statussplit[1],
      learning_rate=FLAGS.lr,
      optimizer=FLAGS.optimizer,
      client_disc_train_steps=FLAGS.num_client_disc_train_steps,
      lr_factor=FLAGS.lr_factor)  
  
  # Training.
  '''
  _, tff_time = _train(
      gan,
      server_gen_inputs_dataset,
      client_gen_inputs_dataset,
      client_real_images_train_tff_data,
      FLAGS.num_client_disc_train_steps,
      FLAGS.num_server_gen_train_steps,
      FLAGS.num_clients_per_round,
      FLAGS.num_rounds,
      FLAGS.num_rounds_per_eval,
      eval_hook_fn,
      FLAGS.num_rounds_per_checkpoint,
      output_dir=FLAGS.root_output_dir,
      exp_name=FLAGS.exp_name)
  '''
  #num_rounds_per_eval = max(int(FLAGS.num_rounds/100),1)
  _, tff_time = _train(
      gan,
      server_gen_inputs_dataset,
      client_gen_inputs_dataset,
      client_real_images_train_tff_data,
      FLAGS.num_client_disc_train_steps,
      FLAGS.num_server_gen_train_steps,
      FLAGS.num_clients_per_round,
      FLAGS.num_rounds,
      FLAGS.num_rounds_per_eval,
      eval_hook_fn,
      FLAGS.num_rounds_per_checkpoint,
      output_dir=FLAGS.root_output_dir,
      exp_name=exp_name) 
  logging.info('Total training time was %4.3f seconds.', tff_time)

  print('\nTRAINING COMPLETE.')
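In the `CEN_CEN` branch above, a single centralized `tf.data.Dataset` is exposed as a one-client federated dataset. A small usage sketch of that wrapper, assuming the same older `tff.simulation.ClientData.from_clients_and_fn` API the example calls (the dataset contents here are dummy tensors):

import tensorflow as tf
import tensorflow_federated as tff

# Stand-in for the centralized image pipeline built in the CEN_CEN branch.
central_dataset = tf.data.Dataset.from_tensor_slices(
    tf.zeros([64, 28, 28, 1])).batch(16)

client_data = tff.simulation.ClientData.from_clients_and_fn(
    ['1'], lambda client_id: central_dataset)

print(client_data.client_ids)  # ['1'] -- the single pseudo-client.
one_client_ds = client_data.create_tf_dataset_for_client('1')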
def client_computation_fedadam(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        state_gen_optimizer: tf.keras.optimizers.Optimizer,
        state_disc_optimizer: tf.keras.optimizers.Optimizer) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    state_gen_optimizer: The stateful generator optimizer, synced from
      `from_server` before training; its delta is reported in the output.
    state_disc_optimizer: The stateful discriminator optimizer, synced from
      `from_server` before training; its delta is reported in the output.

  Returns:
    A `ClientOutput` object.
  """
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                          from_server.discriminator_weights)
    tf.nest.map_structure(conditioned_assign, state_disc_optimizer.variables(),
                          from_server.state_disc_optimizer_weights)
    tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                          from_server.state_gen_optimizer_weights)
    num_examples = tf.constant(0)
    loss_fns = gan_losses.WassersteinGanLossFns()
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))

    sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=1)
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        with tf.GradientTape() as tape:
            loss = loss_fns.discriminator_loss(generator, discriminator,
                                               gen_inputs, real_data)
        grads = tape.gradient(loss, discriminator.trainable_variables)
        grads_and_vars = zip(grads, discriminator.trainable_variables)
        sgd_optimizer.apply_gradients(grads_and_vars)
        num_examples += min_batch_size

    state_disc_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_disc_optimizer.variables()),
        from_server.state_disc_optimizer_weights)
    # should be zero
    state_gen_opt_delta = tf.nest.map_structure(
        tf.subtract,
        tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                              state_gen_optimizer.variables()),
        from_server.state_gen_optimizer_weights)
    weights_delta = tf.nest.map_structure(
        tf.subtract, discriminator.trainable_weights,
        from_server.discriminator_weights.trainable)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                            lambda: update_weight, lambda: tf.constant(0.0))
    return ClientOutput(
        discriminator_weights_delta=weights_delta,
        generator_weights_delta=weights_delta,
        state_disc_opt_delta=state_disc_opt_delta,
        state_gen_opt_delta=state_gen_opt_delta,
        update_weight_D=update_weight,
        update_weight_G=update_weight,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})
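The batch-truncation logic in the client loop can be exercised in isolation. A minimal sketch with toy tensors (shapes chosen arbitrarily), showing how zipping the generator-input and real-data datasets and slicing to the smaller batch size behaves:

import tensorflow as tf

# Toy stand-ins: noise batches of size 6, image batches of size 4.
gen_inputs_ds = tf.data.Dataset.from_tensor_slices(tf.random.normal([12, 32])).batch(6)
real_data_ds = tf.data.Dataset.from_tensor_slices(tf.zeros([12, 28, 28, 1])).batch(4)

for gen_inputs, real_data in tf.data.Dataset.zip((gen_inputs_ds, real_data_ds)):
    # Truncate both tensors to the smaller batch so the discriminator loss sees
    # equally sized real and fake contributions, as in the client computation.
    min_batch_size = tf.minimum(tf.shape(real_data)[0], tf.shape(gen_inputs)[0])
    real_data = real_data[:min_batch_size]
    gen_inputs = gen_inputs[:min_batch_size]
    print(min_batch_size.numpy())  # 4, then 4 (zip stops at the shorter dataset).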