def test_create_client_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
  with flag_sandbox(
      {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): optimizer_name}):
    # Construct a default optimizer.
    default_optimizer = utils_impl.create_optimizer_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertIsInstance(default_optimizer, optimizer_cls)

    # Override the default flag value.
    overridden_learning_rate = 5.0
    custom_optimizer = utils_impl.create_optimizer_from_flags(
        TEST_CLIENT_FLAG_PREFIX,
        overrides={'learning_rate': overridden_learning_rate})
    self.assertIsInstance(custom_optimizer, optimizer_cls)
    self.assertEqual(custom_optimizer.get_config()['learning_rate'],
                     overridden_learning_rate)

    # Override the learning rate flag.
    commandline_set_learning_rate = 100.0
    with flag_sandbox({
        '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX):
            commandline_set_learning_rate
    }):
      custom_optimizer = utils_impl.create_optimizer_from_flags(
          TEST_CLIENT_FLAG_PREFIX)
      self.assertIsInstance(custom_optimizer, optimizer_cls)
      self.assertEqual(custom_optimizer.get_config()['learning_rate'],
                       commandline_set_learning_rate)
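# NOTE: `flag_sandbox` is used by the tests in this excerpt but not defined
# here. A minimal sketch of such a helper, assuming absl flags and the
# module-level `FLAGS = flags.FLAGS` used above (the real test module may
# define it differently):
import contextlib


@contextlib.contextmanager
def flag_sandbox(flag_value_dict):
  """Temporarily overrides flag values, restoring them on exit."""

  def _set_flags(flag_dict):
    for name, value in flag_dict.items():
      FLAGS[name].value = value

  # Save the current values, apply the overrides, and restore afterwards.
  saved_values = {name: FLAGS[name].value for name in flag_value_dict}
  _set_flags(flag_value_dict)
  try:
    yield
  finally:
    _set_flags(saved_values)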
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
      'client')
  server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
      'server')

  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`."""
    model_trainable_variables = model_fn().trainable_variables

    # Most of the logic for deciding what to run is here.
    aggregation_factory = fl_utils.build_aggregator(
        compression_flags=compression_dict,
        dp_flags=dp_dict,
        num_clients=get_total_num_clients(FLAGS.task),
        num_clients_per_round=FLAGS.clients_per_round,
        num_rounds=FLAGS.total_rounds,
        client_template=model_trainable_variables)

    return tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=tff.learning.ClientWeighting.UNIFORM,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
def test_create_optimizer_from_flags_invalid_overrides(self):
  with flag_sandbox({'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
    with self.assertRaisesRegex(TypeError, 'type `collections.Mapping`'):
      _ = utils_impl.create_optimizer_from_flags(
          TEST_CLIENT_FLAG_PREFIX, overrides=[1, 2, 3])
def test_create_optimizer_from_flags_flags_set_not_for_optimizer(self):
  with flag_sandbox({'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
    # Set an Adam flag that isn't used in SGD.
    # We need to use `_parse_args` because that is the only way FLAGS is
    # notified that a non-default value is being used.
    bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
    FLAGS._parse_args(
        args=['--{}=0.5'.format(bad_adam_flag)], known_only=True)
    with self.assertRaisesRegex(
        ValueError,
        r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'):
      _ = utils_impl.create_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
    FLAGS[bad_adam_flag].unparse()
def _run_experiment():
  """Data preprocessing and experiment execution."""
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=FLAGS.digit_only_emnist)

  def preprocess_train_dataset(dataset):
    """Preprocesses the training dataset."""
    return (dataset.map(reshape_emnist_element).shuffle(
        buffer_size=10000).repeat(FLAGS.client_epochs_per_round).batch(
            FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocesses the testing dataset."""
    return dataset.map(reshape_emnist_element).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  input_spec = example_dataset.element_spec

  def tff_model_fn():
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    sampled_clients = random.sample(
        population=emnist_train.client_ids, k=FLAGS.train_clients_per_round)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  def evaluate_fn(state):
    compiled_keras_model = compiled_eval_keras_model()
    state.model.assign_weights_to(compiled_keras_model)
    eval_metrics = compiled_keras_model.evaluate(emnist_test, verbose=0)
    return {
        'loss': eval_metrics[0],
        'sparse_categorical_accuracy': eval_metrics[1],
    }

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict = utils_impl.remove_unused_flags('client', hparam_dict)
  metrics_hook = _MetricsHook(FLAGS.exp_name, FLAGS.root_output_dir,
                              hparam_dict)

  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')
  if FLAGS.server_optimizer == 'sgd':
    server_optimizer_fn = functools.partial(
        tf.keras.optimizers.SGD,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum)
  elif FLAGS.server_optimizer == 'flars':
    server_optimizer_fn = functools.partial(
        flars_optimizer.FLARSOptimizer,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum,
        max_ratio=FLAGS.max_ratio)
  else:
    raise ValueError('Optimizer %s is not supported.' % FLAGS.server_optimizer)

  _federated_averaging_training_loop(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
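# `reshape_emnist_element`, `model_builder`, `loss_builder`, `metrics_builder`
# and `compiled_eval_keras_model` are assumed by `_run_experiment` above but
# not shown in this excerpt. A hedged sketch of the reshape helper only,
# assuming the Keras model consumes 28x28x1 image tensors and integer labels
# (hypothetical; the original experiment may reshape differently):
import tensorflow as tf


def reshape_emnist_element(element):
  # EMNIST elements are dicts with a 28x28 'pixels' tensor and a 'label'.
  return tf.expand_dims(element['pixels'], axis=-1), element['label']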
def test_create_optimizer_from_flags_invalid_optimizer(self):
  FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
  with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
    _ = utils_impl.create_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_task = 'digit_recognition'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=FLAGS.client_epochs_per_round,
      batch_size=FLAGS.client_batch_size,
      emnist_task=emnist_task)
  input_spec = train_preprocess_fn.type_signature.result.element

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=FLAGS.only_digits)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model,
        only_digits=FLAGS.only_digits)
  elif FLAGS.model == '1m_cnn':
    model_builder = functools.partial(
        create_1m_cnn_model, only_digits=FLAGS.only_digits)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)

  # Most of the logic for deciding what baseline to run is here.
  aggregation_factory = fl_utils.build_aggregator(
      compression_flags=compression_dict,
      dp_flags=dp_dict,
      num_clients=len(emnist_train.client_ids),
      num_clients_per_round=FLAGS.clients_per_round,
      num_rounds=FLAGS.total_rounds,
      client_template=model_builder().trainable_variables)

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        loss=loss_builder(),
        input_spec=input_spec,
        metrics=metrics_builder())

  server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('server')
  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=tff_model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=tff.learning.ClientWeighting.UNIFORM,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)
  training_process.get_model_weights = iterative_process.get_model_weights

  client_ids_fn = functools.partial(
      tff.simulation.build_uniform_sampling_fn(
          emnist_train.client_ids,
          replace=False,
          random_seed=FLAGS.client_datasets_random_seed),
      size=FLAGS.clients_per_round)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
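# Entry point, sketched under the assumption that these scripts are launched
# with absl (both `main(argv)` functions above raise `app.UsageError`); the
# original files also define their flags and constants such as
# `_SUPPORTED_TASKS` before this point.
if __name__ == '__main__':
  app.run(main)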