def test_remove_unused_flags_with_empty_optimizer(self):
    """An empty (unset) optimizer flag must be rejected with a ValueError."""
    flags_with_blank_optimizer = collections.OrderedDict()
    flags_with_blank_optimizer['optimizer'] = ''
    expected_message = ('The flag optimizer was not set. '
                        'Unable to determine the relevant optimizer.')
    with self.assertRaisesRegex(ValueError, expected_message):
        utils_impl.remove_unused_flags(
            prefix=None, hparam_dict=flags_with_blank_optimizer)
def test_remove_flags_with_optimizers_sharing_a_prefix(self):
    """Flags of an optimizer whose name prefixes the selected one are dropped.

    `adam` is a prefix of `adamW`. With `adamW` selected, only
    `client_optimizer` and `client_adamW_momentum` should remain, with their
    original values.
    """
    hparams = collections.OrderedDict()
    hparams['client_optimizer'] = 'adamW'
    hparams['client_adam_momentum'] = 0.3
    hparams['client_adamW_momentum'] = 0.5

    pruned = utils_impl.remove_unused_flags('client', hparams)

    self.assertCountEqual(
        pruned.keys(), ['client_optimizer', 'client_adamW_momentum'])
    for flag_name, expected_value in (('client_optimizer', 'adamW'),
                                      ('client_adamW_momentum', 0.5)):
        self.assertEqual(pruned[flag_name], expected_value)
def test_remove_unused_flags_without_prefix(self):
    """With no prefix, flags of unselected optimizers are removed.

    `optimizer` selects `sgd`, so `adam_momentum` is expected to be dropped.
    The non-optimizer flag `value` and `sgd_momentum` are expected to be kept.
    """
    hparams = collections.OrderedDict()
    hparams['optimizer'] = 'sgd'
    hparams['value'] = 0.1
    hparams['sgd_momentum'] = 0.3
    hparams['adam_momentum'] = 0.5

    pruned = utils_impl.remove_unused_flags(prefix=None, hparam_dict=hparams)

    self.assertCountEqual(pruned.keys(), ['optimizer', 'value', 'sgd_momentum'])
    for flag_name, expected_value in (('optimizer', 'sgd'),
                                      ('value', 0.1),
                                      ('sgd_momentum', 0.3)):
        self.assertEqual(pruned[flag_name], expected_value)
def _run_experiment():
    """Data preprocessing and experiment execution.

    Loads EMNIST (digits only when `FLAGS.digit_only_emnist` is set) and
    preprocesses the train and test splits. It then builds the TFF model,
    client-sampling, evaluation and optimizer functions, and runs
    `_federated_averaging_training_loop` for `FLAGS.total_rounds` rounds.

    The server optimizer is resolved first, so an unsupported
    `FLAGS.server_optimizer` fails before any data is loaded or any output
    directory is created.

    Raises:
      ValueError: If `FLAGS.server_optimizer` is neither 'sgd' nor 'flars'.
    """
    # Validate the server optimizer up front. The rest of the setup loads the
    # dataset and creates FLAGS.root_output_dir, which is wasted work (and
    # leaves stray directories behind) if the configuration is invalid.
    if FLAGS.server_optimizer == 'sgd':
        server_optimizer_fn = functools.partial(
            tf.keras.optimizers.SGD,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum)
    elif FLAGS.server_optimizer == 'flars':
        server_optimizer_fn = functools.partial(
            flars_optimizer.FLARSOptimizer,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum,
            max_ratio=FLAGS.max_ratio)
    else:
        raise ValueError('Optimizer %s is not supported.' % FLAGS.server_optimizer)

    emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.digit_only_emnist)

    def preprocess_train_dataset(dataset):
        """Preprocess training dataset.

        Maps `reshape_emnist_element`, shuffles with a 10000-element buffer,
        repeats for `FLAGS.client_epochs_per_round` local epochs, then batches
        by `FLAGS.batch_size`.
        """
        return (dataset.map(reshape_emnist_element)
                .shuffle(buffer_size=10000)
                .repeat(FLAGS.client_epochs_per_round)
                .batch(FLAGS.batch_size))

    def preprocess_test_dataset(dataset):
        """Preprocess testing dataset.

        Maps `reshape_emnist_element` and batches by 100, keeping the final
        partial batch.
        """
        return dataset.map(reshape_emnist_element).batch(100, drop_remainder=False)

    emnist_train = emnist_train.preprocess(preprocess_train_dataset)
    # All test clients are pooled into one dataset for centralized evaluation.
    emnist_test = preprocess_test_dataset(
        emnist_test.create_tf_dataset_from_all_clients())

    # The element spec of one preprocessed client dataset is used as the
    # model's input spec.
    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    def tff_model_fn():
        """Wraps a freshly built Keras model as a TFF model."""
        keras_model = model_builder()
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=input_spec,
            loss=loss_builder(),
            metrics=metrics_builder())

    def client_datasets_fn(round_num):
        """Returns a list of client datasets.

        `FLAGS.train_clients_per_round` clients are sampled uniformly without
        replacement on every call, independent of the round number.
        """
        del round_num  # Unused.
        sampled_clients = random.sample(
            population=emnist_train.client_ids, k=FLAGS.train_clients_per_round)
        return [
            emnist_train.create_tf_dataset_for_client(client)
            for client in sampled_clients
        ]

    def evaluate_fn(state):
        """Evaluates the model weights held in `state` on the pooled test set."""
        compiled_keras_model = compiled_eval_keras_model()
        state.model.assign_weights_to(compiled_keras_model)
        eval_metrics = compiled_keras_model.evaluate(emnist_test, verbose=0)
        # NOTE(review): assumes compiled_eval_keras_model() reports
        # [loss, sparse_categorical_accuracy] in that order -- confirm.
        return {
            'loss': eval_metrics[0],
            'sparse_categorical_accuracy': eval_metrics[1],
        }

    tf.io.gfile.makedirs(FLAGS.root_output_dir)
    hparam_dict = collections.OrderedDict([
        (name, FLAGS[name].value) for name in hparam_flags
    ])
    # Drop hyperparameters of client optimizers other than the selected one,
    # so only the relevant ones are recorded by the metrics hook.
    hparam_dict = utils_impl.remove_unused_flags('client', hparam_dict)
    metrics_hook = _MetricsHook(FLAGS.exp_name, FLAGS.root_output_dir,
                                hparam_dict)

    def client_optimizer_fn():
        """Builds the client optimizer described by the `client_*` flags."""
        return utils_impl.create_optimizer_from_flags('client')

    _federated_averaging_training_loop(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_datasets_fn=client_datasets_fn,
        evaluate_fn=evaluate_fn,
        total_rounds=FLAGS.total_rounds,
        rounds_per_eval=FLAGS.rounds_per_eval,
        metrics_hook=metrics_hook)
def test_remove_unused_flags_without_optimizer_flag(self):
    """A missing `client_optimizer` flag must raise a ValueError."""
    # The optimizer choice is stored under `client_opt_fn` rather than
    # `client_optimizer`, so the expected selector key is absent.
    hparams = collections.OrderedDict(
        [('client_opt_fn', 'sgd'), ('client_sgd_momentum', 0.3)])
    with self.assertRaisesRegex(ValueError,
                                'The flag client_optimizer was not defined.'):
        utils_impl.remove_unused_flags('client', hparams)