def main(argv):
  """Trains a character-level recurrent model on centralized Shakespeare data.

  Args:
    argv: Unused command-line arguments beyond the program name; more than
      one entry raises `app.UsageError`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  tf.compat.v1.enable_v2_behavior()

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  # Record the run's hyperparameters next to its results for reproducibility.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  logging.info('Saving hyper parameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  train_client_data, test_client_data = (
      tff.simulation.datasets.shakespeare.load_data())

  def preprocess(ds):
    # Convert raw text snippets to batched character sequences; cache since
    # the centralized dataset is iterated once per epoch.
    return dataset.convert_snippets_to_character_sequence_examples(
        ds, FLAGS.batch_size, epochs=1).cache()

  train_dataset = train_client_data.create_tf_dataset_from_all_clients()
  if FLAGS.shuffle_train_data:
    train_dataset = train_dataset.shuffle(buffer_size=10000)
  train_dataset = preprocess(train_dataset)
  eval_dataset = preprocess(
      test_client_data.create_tf_dataset_from_all_clients())

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  pad_token, _, _, _ = dataset.get_special_tokens()

  # Vocabulary with one OOV ID and zero for the mask.
  vocab_size = len(dataset.CHAR_VOCAB) + 2
  model = models.create_recurrent_model(
      vocab_size=vocab_size, batch_size=FLAGS.batch_size)
  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[
          keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
      ])

  logging.info('Training model:')
  # Keras `Model.summary()` prints to stdout and returns None, so
  # `logging.info(model.summary())` would log "None". Route the summary
  # lines through the logger instead.
  model.summary(print_fn=logging.info)

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=tensorboard_dir)

  # Reduce the learning rate every 20 epochs.
  def decay_lr(epoch, lr):
    if (epoch + 1) % 20 == 0:
      return lr * 0.1
    else:
      return lr

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      train_dataset,
      validation_data=eval_dataset,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'accuracy']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
def main(argv):
  """Trains a logistic tag-prediction model on centralized StackOverflow data.

  Args:
    argv: Unused command-line arguments beyond the program name; more than
      one entry raises `app.UsageError`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Use the compat.v1 spelling for consistency with the other experiment
  # binaries in this project; bare `tf.enable_v2_behavior` does not exist
  # under TF2.
  tf.compat.v1.enable_v2_behavior()

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  # Record the run's hyperparameters next to its results for reproducibility.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  logging.info('Saving hyper parameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  train_dataset, eval_dataset = dataset.get_centralized_stackoverflow_datasets(
      batch_size=FLAGS.batch_size,
      vocab_tokens_size=FLAGS.vocab_tokens_size,
      vocab_tags_size=FLAGS.vocab_tags_size)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = models.create_logistic_model(
      vocab_tokens_size=FLAGS.vocab_tokens_size,
      vocab_tags_size=FLAGS.vocab_tags_size)

  model.compile(
      loss=tf.keras.losses.BinaryCrossentropy(
          from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
      optimizer=optimizer,
      metrics=[
          tf.keras.metrics.Precision(),
          tf.keras.metrics.Recall(top_k=5)
      ])

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None; send the summary
  # text to the logger instead of logging its return value.
  model.summary(print_fn=logging.info)

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=tensorboard_dir)

  # Reduce the learning rate after a fixed number of epochs.
  def decay_lr(epoch, learning_rate):
    if (epoch + 1) % FLAGS.decay_epochs == 0:
      return learning_rate * FLAGS.lr_decay
    else:
      return learning_rate

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      train_dataset,
      validation_data=eval_dataset,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'precision', 'recall']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
def run_experiment():
  """Runs the centralized StackOverflow next-word-prediction experiment."""
  _, validation_dataset, test_dataset = dataset.construct_word_level_datasets(
      FLAGS.vocab_size, FLAGS.batch_size, 1, FLAGS.sequence_length, -1,
      FLAGS.num_validation_examples)
  train_dataset = dataset.get_centralized_train_dataset(
      FLAGS.vocab_size, FLAGS.batch_size, FLAGS.sequence_length,
      FLAGS.shuffle_buffer_size)

  model = models.create_recurrent_model(
      vocab_size=FLAGS.vocab_size,
      name='stackoverflow-lstm',
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None, so logging its
  # return value would log "None"; route the lines through the logger.
  model.summary(print_fn=logging.info)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  pad_token, oov_token, _, eos_token = dataset.get_special_tokens(
      FLAGS.vocab_size)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=optimizer,
      metrics=[
          # Plus 4 for pad, oov, bos, eos
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_with_oov', masked_tokens=[pad_token]),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov', masked_tokens=[pad_token, oov_token]),
          keras_metrics.MaskedCategoricalAccuracy(
              name='accuracy_no_oov_or_eos',
              masked_tokens=[pad_token, oov_token, eos_token]),
      ])

  train_results_path = os.path.join(FLAGS.root_output_dir, 'train_results',
                                    FLAGS.experiment_name)
  test_results_path = os.path.join(FLAGS.root_output_dir, 'test_results',
                                   FLAGS.experiment_name)

  train_csv_logger = keras_callbacks.AtomicCSVLogger(train_results_path)
  test_csv_logger = keras_callbacks.AtomicCSVLogger(test_results_path)

  log_dir = os.path.join(FLAGS.root_output_dir, 'logdir',
                         FLAGS.experiment_name)
  # Create each directory under its own try so that one pre-existing
  # directory does not skip creation of the others (a single try around
  # all three would abort on the first OpError).
  for path in [log_dir, train_results_path, test_results_path]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  train_tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=log_dir,
      write_graph=True,
      update_freq=FLAGS.tensorboard_update_frequency)
  test_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

  # Write the hyperparameters to a CSV:
  # NOTE(review): this writes under <root>/<experiment>/hparams.csv, a
  # directory not created above — presumably pre-existing; verify.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparams_file = os.path.join(FLAGS.root_output_dir, FLAGS.experiment_name,
                              'hparams.csv')
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  model.fit(
      train_dataset,
      epochs=FLAGS.epochs,
      verbose=0,
      validation_data=validation_dataset,
      callbacks=[train_csv_logger, train_tensorboard_callback])
  score = model.evaluate(
      test_dataset,
      verbose=0,
      callbacks=[test_csv_logger, test_tensorboard_callback])
  logging.info('Final test loss: %.4f', score[0])
  logging.info('Final test accuracy: %.4f', score[1])
def main(argv):
  """Runs federated StackOverflow NWP training with optional DP aggregation.

  Args:
    argv: Unused command-line arguments beyond the program name; any extra
      entries raise `app.UsageError`.

  Raises:
    ValueError: If a noise multiplier is set without uniform client
      weighting (DP is only implemented for uniform weighting).
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))
  tff.backends.native.set_local_execution_context(max_fanout=10)

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
    ]

  datasets = stackoverflow_dataset.construct_word_level_datasets(
      FLAGS.vocab_size, FLAGS.client_batch_size,
      FLAGS.client_epochs_per_round, FLAGS.sequence_length,
      FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
  train_dataset, validation_dataset, test_dataset = datasets

  if FLAGS.uniform_weighting:

    def client_weight_fn(local_outputs):
      del local_outputs
      return 1.0
  else:

    # Weight clients by the number of (non-padding) tokens they processed.
    def client_weight_fn(local_outputs):
      return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=validation_dataset.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_num_clients=FLAGS.clients_per_round,
        per_vector_clipping=FLAGS.per_vector_clipping,
        model=model_fn())
    weights_type = tff.learning.framework.weights_type_from_model(model_fn)
    aggregation_process = tff.utils.build_dp_aggregate_process(
        weights_type.trainable, dp_query)
  else:
    aggregation_process = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      client_optimizer_fn=client_optimizer_fn,
      aggregation_process=aggregation_process)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  test_fn = training_utils.build_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None, so
  # `logging.info(model_builder().summary())` would log "None"; route the
  # summary through the logger instead.
  model_builder().summary(print_fn=logging.info)

  training_loop.run(
      iterative_process, client_datasets_fn, evaluate_fn, test_fn=test_fn)
def main(argv):
  """Trains an EMNIST classifier (CNN or 2NN) on centralized data.

  Args:
    argv: Unused command-line arguments beyond the program name; more than
      one entry raises `app.UsageError`.

  Raises:
    ValueError: If `--model` is not one of 'cnn' or '2nn'.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  # Record the run's hyperparameters next to its results for reproducibility.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  logging.info('Saving hyper parameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  train_dataset, eval_dataset = emnist_dataset.get_centralized_emnist_datasets(
      batch_size=FLAGS.batch_size, only_digits=False)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  if FLAGS.model == 'cnn':
    model = emnist_models.create_conv_dropout_model(only_digits=False)
  elif FLAGS.model == '2nn':
    model = emnist_models.create_two_hidden_layer_model(only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None; send the summary
  # text to the logger instead of logging its return value.
  model.summary(print_fn=logging.info)

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=tensorboard_dir)

  # Reduce the learning rate after a fixed number of epochs.
  def decay_lr(epoch, learning_rate):
    if (epoch + 1) % FLAGS.decay_epochs == 0:
      return learning_rate * FLAGS.lr_decay
    else:
      return learning_rate

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      train_dataset,
      validation_data=eval_dataset,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'sparse_categorical_accuracy']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
def main(argv):
  """Runs a federated task with adaptive (reduce-on-plateau) learning rates.

  Builds client/server reduce-LR-on-plateau callbacks from flags, wraps them
  in an iterative-process builder, and dispatches to the task selected by
  `--task`.

  Args:
    argv: Unused command-line arguments beyond the program name; any extra
      entries raise `app.UsageError`.

  Raises:
    ValueError: If `--task` is not one of the supported tasks.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)
  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)
  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)
  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)
  elif FLAGS.task == 'stackoverflow_nwp':
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)
  elif FLAGS.task == 'stackoverflow_lr':
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)
  else:
    # Fail loudly on an unrecognized task instead of silently doing nothing,
    # matching the sibling dispatch `main`s in this file.
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))
def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
  """Setting the optimizer flag to an unknown name must raise ValueError."""
  optimizer_flag = '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)
  FLAGS[optimizer_flag].value = 'foo'
  with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
    optimizer_utils.create_optimizer_fn_from_flags(TEST_CLIENT_FLAG_PREFIX)
def main(argv):
  """Runs federated EMNIST training with optional differential privacy.

  Args:
    argv: Unused command-line arguments beyond the program name; any extra
      entries raise `app.UsageError`.

  Raises:
    ValueError: If `--model` is not one of 'cnn' or '2nn', or if a noise
      multiplier is set without uniform client weighting.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
      FLAGS.client_batch_size,
      FLAGS.client_epochs_per_round,
      only_digits=False)

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  if FLAGS.uniform_weighting:

    def client_weight_fn(local_outputs):
      del local_outputs
      return 1.0
  else:
    client_weight_fn = None  # Defaults to the number of examples per client.

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=emnist_test.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_num_clients=FLAGS.clients_per_round,
        per_vector_clipping=FLAGS.per_vector_clipping,
        model=model_fn())
    dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)
  else:
    dp_aggregate_fn = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')

  training_process = (
      tff.learning.federated_averaging.build_federated_averaging_process(
          model_fn=model_fn,
          server_optimizer_fn=server_optimizer_fn,
          client_weight_fn=client_weight_fn,
          client_optimizer_fn=client_optimizer_fn,
          stateful_delta_aggregate_fn=dp_aggregate_fn))

  adaptive_clipping = (FLAGS.adaptive_clip_learning_rate > 0)
  training_process = dp_utils.DPFedAvgProcessAdapter(
      training_process, FLAGS.per_vector_clipping, adaptive_clipping)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder,
      assign_weights_to_keras_model=dp_utils.assign_weights_to_keras_model)

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None, so
  # `logging.info(model_builder().summary())` would log "None"; route the
  # summary through the logger instead.
  model_builder().summary(print_fn=logging.info)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
  )
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
    *,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied. If specified,
      `input_spec` is optional, as the necessary type signatures will be taken
      from the computation.

  Returns:
    A `fed_avg_schedule.FederatedAveragingProcessAdapter`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  if dataset_preprocess_comp is not None:
    if input_spec is not None:
      # Warn via the logger rather than `print` so the message reaches the
      # experiment logs like the rest of this codebase's output.
      logging.warning(
          'Specified both `dataset_preprocess_comp` and `input_spec` when '
          'only one is necessary. Ignoring `input_spec` and using type '
          'signature of `dataset_preprocess_comp`.')
    model_input_spec = dataset_preprocess_comp.type_signature.result.element
  else:
    model_input_spec = input_spec

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
def main(argv):
  """Dispatches a federated training run for the task selected via `--task`."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Optimizers and learning-rate schedules, all configured from flags.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)

  def collect_task_flags(prefix):
    """Collects task flags starting with `prefix`, with the prefix stripped."""
    return collections.OrderedDict(
        (flag_name[len(prefix):], FLAGS[flag_name].value)
        for flag_name in task_flags
        if flag_name.startswith(prefix))

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  shared_args['assign_weights_fn'] = assign_weights_fn

  task = FLAGS.task
  if task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)
  elif task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args,
        emnist_model=FLAGS.emnist_cr_model,
        hparam_dict=hparam_dict)
  elif task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)
  elif task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)
  elif task == 'stackoverflow_nwp':
    federated_stackoverflow.run_federated(
        **shared_args,
        **collect_task_flags('so_nwp_'),
        hparam_dict=hparam_dict)
  elif task == 'stackoverflow_lr':
    federated_stackoverflow_lr.run_federated(
        **shared_args,
        **collect_task_flags('so_lr_'),
        hparam_dict=hparam_dict)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))
def main(argv):
  """Trains a ResNet-18 on centralized CIFAR-100 data.

  Args:
    argv: Unused command-line arguments beyond the program name; more than
      one entry raises `app.UsageError`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  tf.compat.v1.enable_v2_behavior()

  experiment_output_dir = FLAGS.root_output_dir
  tensorboard_dir = os.path.join(experiment_output_dir, 'logdir',
                                 FLAGS.experiment_name)
  results_dir = os.path.join(experiment_output_dir, 'results',
                             FLAGS.experiment_name)

  for path in [experiment_output_dir, tensorboard_dir, results_dir]:
    try:
      tf.io.gfile.makedirs(path)
    except tf.errors.OpError:
      pass  # Directory already exists.

  # Record the run's hyperparameters next to its results for reproducibility.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict['results_file'] = results_dir
  hparams_file = os.path.join(results_dir, 'hparams.csv')
  logging.info('Saving hyper parameters to: [%s]', hparams_file)
  utils_impl.atomic_write_to_csv(pd.Series(hparam_dict), hparams_file)

  cifar_train, cifar_test = dataset.get_centralized_cifar100(
      train_batch_size=FLAGS.batch_size, crop_shape=CROP_SHAPE)

  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

  model = resnet_models.create_resnet18(
      input_shape=CROP_SHAPE, num_classes=NUM_CLASSES)
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=optimizer,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None; send the summary
  # text to the logger instead of logging its return value.
  model.summary(print_fn=logging.info)

  csv_logger_callback = keras_callbacks.AtomicCSVLogger(results_dir)
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=tensorboard_dir)

  # Reduce the learning rate after a fixed number of epochs.
  def decay_lr(epoch, learning_rate):
    if (epoch + 1) % FLAGS.decay_epochs == 0:
      return learning_rate * FLAGS.lr_decay
    else:
      return learning_rate

  lr_callback = tf.keras.callbacks.LearningRateScheduler(decay_lr, verbose=1)

  history = model.fit(
      cifar_train,
      validation_data=cifar_test,
      epochs=FLAGS.num_epochs,
      callbacks=[lr_callback, tensorboard_callback, csv_logger_callback])

  logging.info('Final metrics:')
  for name in ['loss', 'sparse_categorical_accuracy']:
    metric = history.history['val_{}'.format(name)][-1]
    logging.info('\t%s: %.4f', name, metric)
def main(argv):
  """Runs a federated task with adaptive (reduce-LR-on-plateau) schedules.

  Builds client/server plateau callbacks from flags, wraps them in an
  iterative-process builder that also accepts a client dataset-preprocessing
  computation, and dispatches to the task selected by `--task`.

  Args:
    argv: Unused command-line arguments beyond the program name; any extra
      entries raise `app.UsageError`.

  Raises:
    ValueError: If `--task` is not one of the supported tasks.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  # Reduce-LR-on-plateau callbacks share the min_delta/min_lr/window/patience
  # flags; only the initial learning rate and decay factor differ per side.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
      dataset_preprocess_comp: Optional[tff.Computation] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of
        values and return a sequence of values, or in TFF type shorthand
        `(U* -> V*)`. If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        dataset_preprocess_comp=dataset_preprocess_comp)

  assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model

  # Arguments shared by every task's `run_federated`.
  common_args = collections.OrderedDict([
      ('iterative_process_builder', iterative_process_builder),
      ('assign_weights_fn', assign_weights_fn),
      ('client_epochs_per_round', FLAGS.client_epochs_per_round),
      ('client_batch_size', FLAGS.client_batch_size),
      ('clients_per_round', FLAGS.clients_per_round),
      ('max_batches_per_client', FLAGS.max_batches_per_client),
      ('client_datasets_random_seed', FLAGS.client_datasets_random_seed)
  ])

  if FLAGS.task == 'cifar100':
    federated_cifar100.run_federated(
        **common_args, crop_size=FLAGS.cifar100_crop_size)
  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **common_args, emnist_model=FLAGS.emnist_cr_model)
  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**common_args)
  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **common_args, sequence_length=FLAGS.shakespeare_sequence_length)
  elif FLAGS.task == 'stackoverflow_nwp':
    # Collect the `so_nwp_`-prefixed flags, stripping the prefix (len 7) so
    # they can be forwarded as keyword arguments.
    so_nwp_flags = collections.OrderedDict()
    for flag_name in FLAGS:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(**common_args, **so_nwp_flags)
  elif FLAGS.task == 'stackoverflow_lr':
    # Same as above for the `so_lr_` prefix (len 6).
    so_lr_flags = collections.OrderedDict()
    for flag_name in FLAGS:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(**common_args, **so_lr_flags)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))