def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.set_verbosity(logging.INFO)

  # Flags.
  hparam_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])
  for k in hparam_dict.keys():
    if hparam_dict[k] is None:
      hparam_dict[k] = 'None'
  for k, v in hparam_dict.items():
    print('{} : {} '.format(k, v))

  tff.framework.set_default_executor(
      tff.framework.local_executor_factory(
          num_clients=FLAGS.num_clients_per_round))

  # Trained classifier model.
  classifier_model = ecm.get_trained_emnist_classifier_model()

  # GAN Models.
  disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

  # Training datasets.
  server_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
  client_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=CLIENT_TRAIN_BATCH_SIZE, noise_dim=FLAGS.noise_dim)

  if FLAGS.filtering == 'by_user':
    client_real_images_train_tff_data = (
        fedu.get_filtered_by_user_client_data_for_training(
            invert_imagery_probability=FLAGS.invert_imagery_probability,
            accuracy_threshold=FLAGS.accuracy_threshold,
            batch_size=CLIENT_TRAIN_BATCH_SIZE))
  elif FLAGS.filtering == 'by_example':
    client_real_images_train_tff_data = (
        fedu.get_filtered_by_example_client_data_for_training(
            invert_imagery_probability=FLAGS.invert_imagery_probability,
            min_num_examples=FLAGS.min_num_examples,
            example_class_selection=FLAGS.example_class_selection,
            batch_size=CLIENT_TRAIN_BATCH_SIZE))
  else:
    client_real_images_train_tff_data = (
        fedu.get_unfiltered_client_data_for_training(
            batch_size=CLIENT_TRAIN_BATCH_SIZE))

  print('There are %d unique clients that will be used for GAN training.' %
        len(client_real_images_train_tff_data.client_ids))

  # Training: GAN Losses and Optimizers.
  gan_loss_fns = gan_losses.WassersteinGanLossFns(
      grad_penalty_lambda=FLAGS.wass_gp_lambda)
  disc_optimizer = tf.keras.optimizers.SGD(learning_rate=0.0005)
  gen_optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)

  # Eval datasets.
  gen_inputs_eval_dataset = _create_gen_inputs_dataset(
      batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
  real_images_eval_dataset = _create_real_images_dataset_for_eval()

  # Eval hook.
  path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                    FLAGS.exp_name)
  logging.info('path_to_output_images is %s', path_to_output_images)
  eval_hook_fn = _get_emnist_eval_hook_fn(
      FLAGS.exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
      gen_inputs_eval_dataset, real_images_eval_dataset,
      FLAGS.num_rounds_per_save_images, path_to_output_images,
      classifier_model)

  # Form the GAN.
  gan = _get_gan(
      gen_model_fn,
      disc_model_fn,
      gan_loss_fns,
      gen_optimizer,
      disc_optimizer,
      server_gen_inputs_dataset,
      client_real_images_train_tff_data,
      use_dp=FLAGS.use_dp,
      dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
      dp_noise_multiplier=FLAGS.dp_noise_multiplier,
      clients_per_round=FLAGS.num_clients_per_round)

  # Training.
  _, tff_time = _train(
      gan,
      server_gen_inputs_dataset,
      client_gen_inputs_dataset,
      client_real_images_train_tff_data,
      FLAGS.num_client_disc_train_steps,
      FLAGS.num_server_gen_train_steps,
      FLAGS.num_clients_per_round,
      FLAGS.num_rounds,
      FLAGS.num_rounds_per_eval,
      eval_hook_fn,
      FLAGS.num_rounds_per_checkpoint,
      output_dir=FLAGS.root_output_dir,
      exp_name=FLAGS.exp_name)
  logging.info('Total training time was %4.3f seconds.', tff_time)

  print('\nTRAINING COMPLETE.')
def server_computation(
    # Tensor/Dataset arguments that will be supplied by TFF:
    server_state: ServerState,
    gen_inputs_ds: tf.data.Dataset,
    client_output: ClientOutput,
    # Python arguments to be bound at TFF computation construction time:
    generator: tf.keras.Model,
    discriminator: tf.keras.Model,
    state_gen_optimizer: tf.keras.optimizers.Optimizer,
    state_disc_optimizer: tf.keras.optimizers.Optimizer,
    # Not an argument bound at TFF computation construction time, but placed
    # last so that it can be defaulted to empty tuple (for non-DP use cases).
    new_dp_averaging_state=()) -> ServerState:
  """The computation to run on the server, training the generator.

  Args:
    server_state: The initial `ServerState` for the round.
    gen_inputs_ds: An infinite `tf.data.Dataset` of inputs to the `generator`.
    client_output: The (possibly aggregated) `ClientOutput`.
    generator: The generator.
    discriminator: The discriminator.
    state_gen_optimizer: The stateful optimizer used to train the generator;
      its variables are restored from `server_state` at the start of the round.
    state_disc_optimizer: The stateful discriminator optimizer. It is not
      applied directly here; its tracked state in `server_state` is updated
      with the aggregated client delta.
    new_dp_averaging_state: The updated state of the DP averaging aggregator.

  Returns:
    An updated `ServerState` object.
  """
  # A tf.function can't modify the structure of its input arguments,
  # so we make a semi-shallow copy:
  server_state = attr.evolve(
      server_state, counters=dict(server_state.counters))

  tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                        server_state.state_gen_optimizer_weights)
  tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                        server_state.generator_weights)
  tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                        server_state.discriminator_weights)

  # With a learning rate of 1.0, applying the negated delta as a "gradient"
  # simply adds the delta to the current discriminator weights.
  server_disc_update_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

  delta = client_output.discriminator_weights_delta
  tf.nest.assert_same_structure(delta, discriminator.trainable_weights)
  grads_and_vars = tf.nest.map_structure(lambda x, v: (-1.0 * x, v), delta,
                                         discriminator.trainable_weights)
  server_disc_update_optimizer.apply_gradients(
      grads_and_vars, name='server_update_disc')

  for k, v in client_output.counters.items():
    server_state.counters[k] += v

  # Update the state of the DP averaging aggregator.
  server_state.dp_averaging_state = new_dp_averaging_state

  gen_examples_this_round = tf.constant(0)
  loss_fns = gan_losses.WassersteinGanLossFns()
  for gen_inputs in gen_inputs_ds:  # Compiled by autograph.
    with tf.GradientTape() as tape:
      loss = loss_fns.generator_loss(generator, discriminator, gen_inputs)
    grads = tape.gradient(loss, generator.trainable_variables)
    state_gen_optimizer.apply_gradients(
        zip(grads, generator.trainable_variables))
    gen_examples_this_round += tf.shape(gen_inputs)[0]

  # Update the server's mirror of the discriminator optimizer state with the
  # aggregated client delta.
  delta_opt_D = client_output.state_disc_opt_delta
  updated_opt_D = tf.nest.map_structure(
      lambda a, b: a + b, server_state.state_disc_optimizer_weights,
      delta_opt_D)

  # Record how far each network moved this round (new weights minus the
  # weights the round started from).
  G_change = tf.nest.map_structure(tf.subtract, generator.trainable_weights,
                                   server_state.generator_weights.trainable)
  D_change = tf.nest.map_structure(
      tf.subtract, discriminator.trainable_weights,
      server_state.discriminator_weights.trainable)

  server_state.state_gen_optimizer_weights = tf.nest.map_structure(
      lambda x: tf.cast(x, tf.float32), state_gen_optimizer.variables())
  server_state.state_disc_optimizer_weights = updated_opt_D
  server_state.counters[
      'num_generator_train_examples'] += gen_examples_this_round
  server_state.generator_weights = _weights(generator)
  server_state.discriminator_weights = _weights(discriminator)
  server_state.counters['num_rounds'] += 1
  server_state.generator_diff = G_change
  server_state.discriminator_diff = D_change
  return server_state
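# A minimal, self-contained sketch (hypothetical `_ExampleState` class, not
# part of this repo) of the "semi-shallow copy" idiom used in
# server_computation above: `attr.evolve` returns a new attrs instance, and
# copying the counters dict means in-place counter updates never mutate the
# caller's ServerState.
def _demo_semi_shallow_copy():
  @attr.s
  class _ExampleState(object):
    counters = attr.ib()

  state = _ExampleState(counters={'num_rounds': 0})
  round_state = attr.evolve(state, counters=dict(state.counters))
  round_state.counters['num_rounds'] += 1
  assert state.counters['num_rounds'] == 0  # The original is untouched.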
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  logging.set_verbosity(logging.INFO)

  # Flags.
  hparam_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in hparam_flags])
  for k in hparam_dict.keys():
    if hparam_dict[k] is None:
      hparam_dict[k] = 'None'
  for k, v in hparam_dict.items():
    print('{} : {} '.format(k, v))

  tff.backends.native.set_local_execution_context(
      num_clients=FLAGS.num_clients_per_round)

  exp_name = ('client_steps={0},clients_per_round={1},status={2},'
              'model_name={3},client_batch={4},optimizer={5},lr={6},'
              'lr_factor={7}').format(FLAGS.num_client_disc_train_steps,
                                      FLAGS.num_clients_per_round,
                                      FLAGS.status, FLAGS.model,
                                      FLAGS.client_batch_size,
                                      FLAGS.optimizer, FLAGS.lr,
                                      FLAGS.lr_factor)

  cache_dir = os.path.join(FLAGS.root_output_dir, exp_name)
  cache_subdir = os.path.join(cache_dir, 'datasets')
  if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)
  if not os.path.exists(cache_subdir):
    os.makedirs(cache_subdir)

  # Seed the per-experiment cache with pre-downloaded EMNIST archives.
  # NOTE: the source paths are hardcoded to the original author's machine.
  if not os.path.exists(
      os.path.join(cache_subdir, 'fed_emnist_digitsonly_train.h5')):
    keras_datasets_dir = '/home/houc/.keras/datasets'
    for fname in ('fed_emnist.tar.bz2', 'fed_emnist_test.h5',
                  'fed_emnist_train.h5', 'fed_emnist_digitsonly.tar.bz2',
                  'fed_emnist_digitsonly_test.h5',
                  'fed_emnist_digitsonly_train.h5'):
      copyfile(os.path.join(keras_datasets_dir, fname),
               os.path.join(cache_subdir, fname))

  client_batch_size = FLAGS.client_batch_size

  # Trained classifier model.
  classifier_model = ecm.get_trained_emnist_classifier_model()

  # GAN Models.
  if FLAGS.model == 'spectral':
    disc_model_fn, gen_model_fn = _get_gan_network_models(
        FLAGS.noise_dim, True)
  else:
    disc_model_fn, gen_model_fn = _get_gan_network_models(FLAGS.noise_dim)

  # Training datasets.
  server_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=FLAGS.server_train_batch_size, noise_dim=FLAGS.noise_dim)
  client_gen_inputs_dataset = _create_gen_inputs_dataset(
      batch_size=client_batch_size, noise_dim=FLAGS.noise_dim)

  if FLAGS.status == 'CEN_CEN':
    # Fully centralized training: all real data lives in a single "client".
    if FLAGS.model.split('_')[0].lower() == 'mnist':
      (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

      def preprocess_images(images):
        return np.float32(images.reshape((images.shape[0], 28, 28, 1)) / 255)

      train_images = preprocess_images(train_images)
      test_images = preprocess_images(test_images)
      train_size = 60000
      test_size = 10000
      central_dataset = (
          tf.data.Dataset.from_tensor_slices(train_images).shuffle(
              train_size,
              reshuffle_each_iteration=True).batch(client_batch_size))
    else:
      central_dataset = _create_real_images_dataset_for_central(
          client_batch_size, cache_dir)
    print('Dataset done', flush=True)

    def create_tf_dataset_for_client(client_id):
      del client_id  # The single centralized "client" always gets all data.
      return central_dataset

    client_real_images_train_tff_data = (
        tff.simulation.ClientData.from_clients_and_fn(
            ['1'], create_tf_dataset_for_client))
  else:
    client_real_images_train_tff_data = (
        fedu.get_unfiltered_client_data_for_training(
            batch_size=client_batch_size, cache_dir=cache_dir))

  print('There are %d unique clients that will be used for GAN training.' %
        len(client_real_images_train_tff_data.client_ids))

  # Training: GAN Losses and Optimizers. Both optimizers use a learning rate
  # of 1.0 here; the effective per-role learning rates are configured via
  # FLAGS.lr and FLAGS.lr_factor when the GAN process is built below.
  gan_loss_fns = gan_losses.WassersteinGanLossFns(
      grad_penalty_lambda=FLAGS.wass_gp_lambda)
  disc_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
  gen_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)

  # Eval datasets.
  gen_inputs_eval_dataset = _create_gen_inputs_dataset(
      batch_size=EVAL_BATCH_SIZE, noise_dim=FLAGS.noise_dim)
  if FLAGS.model.split('_')[0].lower() == 'mnist':
    # NOTE: `test_images` and `test_size` are only defined in the MNIST branch
    # above, so this assumes status == 'CEN_CEN' whenever the model is MNIST.
    real_images_eval_dataset = (
        tf.data.Dataset.from_tensor_slices(test_images).shuffle(
            test_size, reshuffle_each_iteration=True).batch(EVAL_BATCH_SIZE))
  else:
    real_images_eval_dataset = _create_real_images_dataset_for_eval(cache_dir)

  # Eval hook.
  path_to_output_images = _get_path_to_output_image(FLAGS.root_output_dir,
                                                    exp_name)
  logging.info('path_to_output_images is %s', path_to_output_images)
  eval_hook_fn = _get_emnist_eval_hook_fn(
      exp_name, FLAGS.root_output_dir, hparam_dict, gan_loss_fns,
      gen_inputs_eval_dataset, real_images_eval_dataset,
      FLAGS.num_rounds_per_save_images, path_to_output_images,
      classifier_model)

  # Form the GAN. The status flag encodes '<gen_status>_<disc_status>'.
  gen_status, disc_status = FLAGS.status.split('_')
  gan = _get_gan(
      gen_model_fn,
      disc_model_fn,
      gan_loss_fns,
      gen_optimizer,
      disc_optimizer,
      server_gen_inputs_dataset,
      client_real_images_train_tff_data,
      use_dp=FLAGS.use_dp,
      dp_l2_norm_clip=FLAGS.dp_l2_norm_clip,
      dp_noise_multiplier=FLAGS.dp_noise_multiplier,
      clients_per_round=FLAGS.num_clients_per_round,
      gen_status=gen_status,
      disc_status=disc_status,
      learning_rate=FLAGS.lr,
      optimizer=FLAGS.optimizer,
      client_disc_train_steps=FLAGS.num_client_disc_train_steps,
      lr_factor=FLAGS.lr_factor)

  # Training.
  _, tff_time = _train(
      gan,
      server_gen_inputs_dataset,
      client_gen_inputs_dataset,
      client_real_images_train_tff_data,
      FLAGS.num_client_disc_train_steps,
      FLAGS.num_server_gen_train_steps,
      FLAGS.num_clients_per_round,
      FLAGS.num_rounds,
      FLAGS.num_rounds_per_eval,
      eval_hook_fn,
      FLAGS.num_rounds_per_checkpoint,
      output_dir=FLAGS.root_output_dir,
      exp_name=exp_name)
  logging.info('Total training time was %4.3f seconds.', tff_time)

  print('\nTRAINING COMPLETE.')
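# For reference, a minimal sketch of what a generator-inputs dataset like the
# ones built by _create_gen_inputs_dataset above could look like. This is an
# assumption based on its call sites (batch_size, noise_dim), not the repo's
# actual helper: an infinite stream of batched Gaussian noise vectors.
def _sketch_gen_inputs_dataset(batch_size, noise_dim):
  return tf.data.Dataset.from_tensors(0).repeat().map(
      lambda _: tf.random.normal([batch_size, noise_dim]))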
def client_computation_fedadam(
    # Tensor/Dataset arguments that will be supplied by TFF:
    gen_inputs_ds: tf.data.Dataset,
    real_data_ds: tf.data.Dataset,
    from_server: FromServer,
    # Python arguments to be bound at TFF computation construction time:
    generator: tf.keras.Model,
    discriminator: tf.keras.Model,
    state_gen_optimizer: tf.keras.optimizers.Optimizer,
    state_disc_optimizer: tf.keras.optimizers.Optimizer) -> ClientOutput:
  """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator: The generator.
    discriminator: The discriminator.
    state_gen_optimizer: The stateful generator optimizer; its variables are
      initialized from `from_server` but not updated on the client.
    state_disc_optimizer: The stateful discriminator optimizer; its variables
      are initialized from `from_server`.

  Returns:
    A `ClientOutput` object.
  """
  tf.nest.map_structure(lambda a, b: a.assign(b), _weights(generator),
                        from_server.generator_weights)
  tf.nest.map_structure(lambda a, b: a.assign(b), _weights(discriminator),
                        from_server.discriminator_weights)
  tf.nest.map_structure(conditioned_assign, state_disc_optimizer.variables(),
                        from_server.state_disc_optimizer_weights)
  tf.nest.map_structure(conditioned_assign, state_gen_optimizer.variables(),
                        from_server.state_gen_optimizer_weights)

  num_examples = tf.constant(0)
  loss_fns = gan_losses.WassersteinGanLossFns()
  gen_inputs_and_real_data = tf.data.Dataset.zip((gen_inputs_ds, real_data_ds))
  # Local updates use plain SGD with a learning rate of 1.0; the aggregated
  # weight delta is applied by the server.
  sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
  for gen_inputs, real_data in gen_inputs_and_real_data:
    # It's possible that real_data and gen_inputs have different batch sizes.
    # For calculating the discriminator loss, it's desirable to have
    # equal-sized contributions from both the real and fake data. Also, it's
    # necessary if using the Wasserstein gradient penalty (where a difference
    # is taken b/w the real and fake data). So here we reduce to the min batch
    # size. This also ensures num_examples properly reflects the amount of
    # data trained on.
    min_batch_size = tf.minimum(
        tf.shape(real_data)[0], tf.shape(gen_inputs)[0])
    real_data = real_data[0:min_batch_size]
    gen_inputs = gen_inputs[0:min_batch_size]
    with tf.GradientTape() as tape:
      loss = loss_fns.discriminator_loss(generator, discriminator, gen_inputs,
                                         real_data)
    grads = tape.gradient(loss, discriminator.trainable_variables)
    sgd_optimizer.apply_gradients(
        zip(grads, discriminator.trainable_variables))
    num_examples += min_batch_size

  state_disc_opt_delta = tf.nest.map_structure(
      tf.subtract,
      tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                            state_disc_optimizer.variables()),
      from_server.state_disc_optimizer_weights)
  # Should be zero: the generator optimizer is not used on the client.
  state_gen_opt_delta = tf.nest.map_structure(
      tf.subtract,
      tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                            state_gen_optimizer.variables()),
      from_server.state_gen_optimizer_weights)

  weights_delta = tf.nest.map_structure(
      tf.subtract, discriminator.trainable_weights,
      from_server.discriminator_weights.trainable)
  weights_delta, has_non_finite_delta = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
  update_weight = tf.cast(num_examples, tf.float32)
  # Zero out the weight if there are any non-finite values.
  # TODO(b/122071074): federated_mean might not do the right thing if
  # all clients have zero weight.
  update_weight = tf.cond(
      tf.equal(has_non_finite_delta, 0), lambda: update_weight,
      lambda: tf.constant(0.0))
  return ClientOutput(
      discriminator_weights_delta=weights_delta,
      # Note: the generator is not trained on the client; the discriminator
      # delta is reused for this field.
      generator_weights_delta=weights_delta,
      state_disc_opt_delta=state_disc_opt_delta,
      state_gen_opt_delta=state_gen_opt_delta,
      update_weight_D=update_weight,
      update_weight_G=update_weight,
      update_weight=update_weight,
      counters={'num_discriminator_train_examples': num_examples})
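# A minimal sketch of the zero-on-non-finite guard used above, with a
# stand-in (assumed, not the actual tensor_utils implementation) that matches
# the contract at the call site: return the structure with every tensor
# zeroed if any entry is NaN/Inf, plus an integer flag for tf.cond.
def _sketch_zero_all_if_any_non_finite(structure):
  flat = tf.nest.flatten(structure)
  any_non_finite = tf.reduce_any(
      [tf.reduce_any(~tf.math.is_finite(t)) for t in flat])
  zeroed = tf.nest.map_structure(
      lambda t: tf.cond(any_non_finite, lambda: tf.zeros_like(t), lambda: t),
      structure)
  return zeroed, tf.cast(any_non_finite, tf.int32)


# Example: a delta containing a NaN is zeroed and flagged, so the client's
# update_weight above would become 0.0.
# delta, flag = _sketch_zero_all_if_any_non_finite(
#     [tf.constant([1.0, float('nan')])])  # -> [zeros], flag == 1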