def test_create_optimizer_fn_with_no_learning_rate(self):
    """Building an optimizer fn with the learning-rate flag unset must fail.

    Selects SGD for the client prefix, clears its learning-rate flag, and
    expects `create_optimizer_fn_from_flags` to raise a ValueError whose
    message matches 'Learning rate'.
    """
    overrides = {
        '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd',
        '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): None,
    }
    with flag_sandbox(overrides), \
            self.assertRaisesRegex(ValueError, 'Learning rate'):
        optimizer_utils.create_optimizer_fn_from_flags(TEST_CLIENT_FLAG_PREFIX)
def test_create_optimizer_fn_from_flags_flags_set_not_for_optimizer(self):
    """Setting a flag that belongs to a different optimizer must raise.

    With SGD selected for the client prefix, an explicitly set Adam-only flag
    (`<prefix>_adam_beta_1`) should make `create_optimizer_fn_from_flags`
    raise a ValueError that names both the selected optimizer and the
    offending flag.
    """
    with flag_sandbox(
            {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        # Set an Adam flag that isn't used in SGD.
        # We need to use `_parse_args` because that is the only way FLAGS is
        # notified that a non-default value is being used.
        bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
        FLAGS._parse_args(args=['--{}=0.5'.format(bad_adam_flag)],
                          known_only=True)
        try:
            with self.assertRaisesRegex(
                    ValueError,
                    r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'
            ):
                optimizer_utils.create_optimizer_fn_from_flags(
                    TEST_CLIENT_FLAG_PREFIX)
        finally:
            # Always reset the parsed flag, even when the assertion above
            # fails, so the non-default value does not leak into other tests.
            FLAGS[bad_adam_flag].unparse()
def test_create_server_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
    """Server-prefix flags yield the requested optimizer and learning rate.

    The optimizer fn built from flags is checked twice, once with no argument
    and once with an explicit learning rate. In both cases the result must
    be an instance of `optimizer_cls`. Its configured learning rate must
    equal the flag value in the first case and the explicit argument in the
    second.
    """
    flag_lr = 100.0
    explicit_lr = 11.0
    sandbox_values = {
        '{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX): optimizer_name,
        '{}_learning_rate'.format(TEST_SERVER_FLAG_PREFIX): flag_lr,
    }
    with flag_sandbox(sandbox_values):
        optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            TEST_SERVER_FLAG_PREFIX)
        # (positional args to optimizer_fn, expected learning rate)
        cases = (
            ((), flag_lr),
            ((explicit_lr,), explicit_lr),
        )
        for call_args, expected_lr in cases:
            built = optimizer_fn(*call_args)
            self.assertIsInstance(built, optimizer_cls)
            self.assertEqual(built.get_config()['learning_rate'], expected_lr)
def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
    """An unrecognized optimizer name must be rejected with a ValueError."""
    # The bogus value is set through flag_sandbox, as the sibling tests do,
    # instead of assigning FLAGS[...].value directly. The sandbox is expected
    # to restore the previous value on exit, so 'foo' does not leak into
    # later tests.
    with flag_sandbox({'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'foo'}):
        with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
def main(argv):
    """Runs federated training for the task selected by `--task`.

    Builds the client/server optimizer fns and learning-rate schedules from
    flags and wraps them in an iterative-process builder. It then configures
    the task-specific runner spec, writes the hyperparameter flags, and
    hands everything to `training_loop.run`.

    Args:
      argv: Positional arguments remaining after flag parsing; only the
        program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are given.
      ValueError: If `--task` names a task with no configurer below.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """

        def weight_by_num_tokens(local_outputs):
            # Client weight = the (squeezed) 'num_tokens' entry of the client's
            # local outputs, cast to float32.
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

        # Only the language-model tasks weight clients by token count; all
        # other tasks pass None.
        if FLAGS.task in ('shakespeare', 'stackoverflow_nwp'):
            weight_fn = weight_by_num_tokens
        else:
            weight_fn = None

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            tau=FLAGS.tau,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Each entry is a zero-arg callable so task-specific flags are only read
    # (and the runner is only configured) for the selected task.
    configurers = {
        'cifar100': lambda: federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images,
            cache_dir=FLAGS.cache_dir),
        'emnist_cr': lambda: federated_emnist.configure_training(
            task_spec,
            model=FLAGS.emnist_cr_model,
            cache_dir=FLAGS.cache_dir),
        'emnist_ae': lambda: federated_emnist_ae.configure_training(
            task_spec,
            cache_dir=FLAGS.cache_dir),
        'shakespeare': lambda: federated_shakespeare.configure_training(
            task_spec,
            sequence_length=FLAGS.shakespeare_sequence_length,
            cache_dir=FLAGS.cache_dir),
        'stackoverflow_nwp': lambda: federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples,
            cache_dir=FLAGS.cache_dir),
        'stackoverflow_lr': lambda: federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples,
            cache_dir=FLAGS.cache_dir),
    }

    configure = configurers.get(FLAGS.task)
    if configure is None:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    runner_spec = configure()

    _write_hparam_flags()

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        rounds_per_profile=FLAGS.rounds_per_profile)
def main(argv):
    """Runs centralized training for the task selected by `--task`.

    Builds the optimizer from the 'centralized' flag prefix and collects the
    hyperparameter flag values listed in `hparam_flags`. It then calls the
    task's `run_centralized` with the shared arguments plus any
    task-specific ones.

    Args:
      argv: Positional arguments remaining after flag parsing; only the
        program name is accepted.

    Raises:
      app.UsageError: If extra positional arguments are given.
      ValueError: If `--task` is not one of the handled tasks.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    def _flags_with_prefix(prefix):
        """Maps each flag named `<prefix><rest>` to `rest` -> its value.

        Slicing by `len(prefix)` keeps the stripped names correct if the
        prefix ever changes, unlike hard-coded offsets.
        """
        return collections.OrderedDict(
            (flag_name[len(prefix):], FLAGS[flag_name].value)
            for flag_name in FLAGS
            if flag_name.startswith(prefix))

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    hparams_dict = collections.OrderedDict(
        [(name, FLAGS[name].value) for name in hparam_flags])

    common_args = collections.OrderedDict([
        ('optimizer', optimizer),
        ('experiment_name', FLAGS.experiment_name),
        ('root_output_dir', FLAGS.root_output_dir),
        ('num_epochs', FLAGS.num_epochs),
        ('batch_size', FLAGS.batch_size),
        ('decay_epochs', FLAGS.decay_epochs),
        ('lr_decay', FLAGS.lr_decay),
        ('hparams_dict', hparams_dict),
    ])

    if FLAGS.task == 'cifar100':
        centralized_cifar100.run_centralized(
            **common_args,
            crop_size=FLAGS.cifar100_crop_size,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'emnist_cr':
        centralized_emnist.run_centralized(
            **common_args,
            emnist_model=FLAGS.emnist_cr_model,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'emnist_ae':
        centralized_emnist_ae.run_centralized(
            **common_args, cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'shakespeare':
        centralized_shakespeare.run_centralized(
            **common_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'stackoverflow_nwp':
        # Every `so_nwp_*` flag is forwarded as a keyword argument with the
        # prefix stripped.
        centralized_stackoverflow.run_centralized(
            **common_args,
            **_flags_with_prefix('so_nwp_'),
            cache_dir=FLAGS.cache_dir)
    elif FLAGS.task == 'stackoverflow_lr':
        # Every `so_lr_*` flag is forwarded as a keyword argument with the
        # prefix stripped.
        centralized_stackoverflow_lr.run_centralized(
            **common_args,
            **_flags_with_prefix('so_lr_'),
            cache_dir=FLAGS.cache_dir)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))