def test_create_constant_client_lr_schedule_from_flags(self):
  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant',
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 3.0, err=1e-5)
    self.assertNear(lr_schedule(1), 3.0, err=1e-5)
    self.assertNear(lr_schedule(105), 3.0, err=1e-5)
    self.assertNear(lr_schedule(1042), 3.0, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'constant',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 0.3, err=1e-5)
    self.assertNear(lr_schedule(1), 0.6, err=1e-5)
    self.assertNear(lr_schedule(10), 3.0, err=1e-5)
    self.assertNear(lr_schedule(11), 3.0, err=1e-5)
    self.assertNear(lr_schedule(115), 3.0, err=1e-5)
    self.assertNear(lr_schedule(1052), 3.0, err=1e-5)
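# The warmup behavior asserted above is consistent with a linear ramp of the
# form min((round_num + 1) / warmup_steps, 1.0) multiplied into the base rate.
# This is a minimal sketch inferred from the test values, not the actual
# implementation in optimizer_utils:
def _constant_with_warmup(base_lr, warmup_steps):
  def schedule(round_num):
    warmup_factor = min(float(round_num + 1) / warmup_steps, 1.0)
    return base_lr * warmup_factor
  return schedule

# _constant_with_warmup(3.0, 10) gives 0.3 at round 0, 0.6 at round 1, and
# saturates at 3.0 from round 10 onward, matching the assertions above.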
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss` object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')
  model_input_spec = input_spec

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
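# A minimal usage sketch for `from_flags`, assuming the client/server
# optimizer and LR-schedule flags (e.g. --client_optimizer,
# --client_learning_rate, --client_lr_schedule) have been defined via
# optimizer_utils and parsed before this runs. The toy model builder and
# input spec below are hypothetical placeholders, not part of the module.
def _toy_model_builder():
  # Uncompiled Keras model, as required by `from_flags`.
  return tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, activation='softmax'),
  ])

_toy_input_spec = (
    tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
)

iterative_process = from_flags(
    input_spec=_toy_input_spec,
    model_builder=_toy_model_builder,
    loss_builder=tf.keras.losses.SparseCategoricalCrossentropy,
    metrics_builder=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])
server_state = iterative_process.initialize()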
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(
        task_spec,
        crop_size=FLAGS.cifar100_crop_size,
        distort_train_images=FLAGS.cifar100_distort_train_images)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(
        task_spec, model=FLAGS.emnist_cr_model)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(
        task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(
        task_spec,
        vocab_size=FLAGS.so_nwp_vocab_size,
        num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
        sequence_length=FLAGS.so_nwp_sequence_length,
        max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
        num_validation_examples=FLAGS.so_nwp_num_validation_examples)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(
        task_spec,
        vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
        vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
        max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
        num_validation_examples=FLAGS.so_lr_num_validation_examples)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  _write_hparam_flags()

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_profile=FLAGS.rounds_per_profile)
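# The if/elif task dispatch above could equivalently be table-driven. This is
# a sketch using the same modules and flags as main(), abridged to four tasks
# (the StackOverflow tasks follow the same pattern); it is not how the
# original script is written:
def _configure_task(task_spec):
  task_configurators = {
      'cifar100': lambda: federated_cifar100.configure_training(
          task_spec,
          crop_size=FLAGS.cifar100_crop_size,
          distort_train_images=FLAGS.cifar100_distort_train_images),
      'emnist_cr': lambda: federated_emnist.configure_training(
          task_spec, model=FLAGS.emnist_cr_model),
      'emnist_ae': lambda: federated_emnist_ae.configure_training(task_spec),
      'shakespeare': lambda: federated_shakespeare.configure_training(
          task_spec, sequence_length=FLAGS.shakespeare_sequence_length),
  }
  if FLAGS.task not in task_configurators:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))
  return task_configurators[FLAGS.task]()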
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  if FLAGS.schedule == 'importance':
    fed_avg_schedule = importance_schedule
  elif FLAGS.schedule == 'loss':
    fed_avg_schedule = fed_loss
  else:
    fed_avg_schedule = fed_avg

  if FLAGS.schedule == 'importance':

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
      """Creates an importance-sampling iterative process from `model_fn`."""
      factory = importance_aggregation_factory.ImportanceSamplingFactory(
          FLAGS.clients_per_round)
      weights_type = importance_aggregation_factory.weights_type_from_model_fn(
          model_fn)
      importance_aggregation_process = factory.create(
          value_type=weights_type, weight_type=tff.TensorType(tf.float32))
      return importance_schedule.build_fed_avg_process(
          model_fn=model_fn,
          client_optimizer_fn=client_optimizer_fn,
          client_lr=client_lr_schedule,
          server_optimizer_fn=server_optimizer_fn,
          server_lr=server_lr_schedule,
          aggregation_process=importance_aggregation_process)

  elif FLAGS.schedule == 'loss':

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
      """Creates an iterative process using a given TFF `model_fn`.

      Args:
        model_fn: A no-arg function returning a `tff.learning.Model`.
        client_weight_fn: Optional function that takes the output of
          `model.report_local_outputs` and returns a tensor providing the
          weight in the federated average of model deltas. If not provided,
          the default is the total number of examples processed on device.

      Returns:
        A `tff.templates.IterativeProcess`.
      """
      return fed_avg_schedule.build_fed_avg_process(
          total_clients=FLAGS.loss_pool_size,
          effective_num_clients=FLAGS.clients_per_round,
          model_fn=model_fn,
          client_optimizer_fn=client_optimizer_fn,
          client_lr=client_lr_schedule,
          server_optimizer_fn=server_optimizer_fn,
          server_lr=server_lr_schedule,
          client_weight_fn=client_weight_fn,
          aggregation_process=None)

  else:

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
      """Creates an iterative process using a given TFF `model_fn`.

      Args:
        model_fn: A no-arg function returning a `tff.learning.Model`.
        client_weight_fn: Optional function that takes the output of
          `model.report_local_outputs` and returns a tensor providing the
          weight in the federated average of model deltas. If not provided,
          the default is the total number of examples processed on device.

      Returns:
        A `tff.templates.IterativeProcess`.
      """
      return fed_avg_schedule.build_fed_avg_process(
          model_fn=model_fn,
          client_optimizer_fn=client_optimizer_fn,
          client_lr=client_lr_schedule,
          server_optimizer_fn=server_optimizer_fn,
          server_lr=server_lr_schedule,
          client_weight_fn=client_weight_fn)

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  task_args = _get_task_args()
  hparam_dict = _get_hparam_flags()
  # shared_args['prob_transmit'] = FLAGS.prob_transmit

  if FLAGS.task == 'cifar100':
    run_federated_fn = federated_cifar100.run_federated
  elif FLAGS.task == 'emnist_cr':
    run_federated_fn = federated_emnist.run_federated
  elif FLAGS.task == 'emnist_ae':
    run_federated_fn = federated_emnist_ae.run_federated
  elif FLAGS.task == 'shakespeare':
    run_federated_fn = federated_shakespeare.run_federated
  elif FLAGS.task == 'stackoverflow_nwp':
    run_federated_fn = federated_stackoverflow.run_federated
  elif FLAGS.task == 'stackoverflow_lr':
    run_federated_fn = federated_stackoverflow_lr.run_federated
  elif FLAGS.task == 'synthetic':
    run_federated_fn = federated_synthetic.run_federated
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  run_federated_fn(
      **shared_args,
      **task_args,
      beta=FLAGS.beta,
      hparam_dict=hparam_dict,
      schedule=FLAGS.schedule)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)

  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  task_args = _get_task_args()
  hparam_dict = _get_hparam_flags()

  if FLAGS.task == 'cifar100':
    run_federated_fn = federated_cifar100.run_federated
  elif FLAGS.task == 'emnist_cr':
    run_federated_fn = federated_emnist.run_federated
  elif FLAGS.task == 'emnist_ae':
    run_federated_fn = federated_emnist_ae.run_federated
  elif FLAGS.task == 'shakespeare':
    run_federated_fn = federated_shakespeare.run_federated
  elif FLAGS.task == 'stackoverflow_nwp':
    run_federated_fn = federated_stackoverflow.run_federated
  elif FLAGS.task == 'stackoverflow_lr':
    run_federated_fn = federated_stackoverflow_lr.run_federated
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  run_federated_fn(**shared_args, **task_args, hparam_dict=hparam_dict)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  client_mixedin_schedule_fn = fed_pa_schedule.create_mixin_check_fn(
      name=FLAGS.client_mixin_check_scheme,
      num_mixin_epochs=FLAGS.client_mixin_epochs_per_round,
      start_round=FLAGS.client_mixin_check_start_round)
  client_update_delta_fn = fed_pa_schedule.create_update_delta_fn(
      name=FLAGS.client_update_delta_scheme, rho=FLAGS.client_shrinkage_rho)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model]
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return fed_pa_schedule.build_fed_pa_process(
        model_fn=model_fn,
        client_update_epochs=FLAGS.client_epochs_per_round,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_mixedin_schedule_fn=client_mixedin_schedule_fn,
        client_update_delta_fn=client_update_delta_fn,
        mask_zeros_in_client_updates=FLAGS.mask_zeros_in_client_updates)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      # Since the number of epochs each client makes every round is handled
      # by the logic in client update functions, here we set it to 1.
      client_epochs_per_round=1,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(
        task_spec,
        crop_size=FLAGS.cifar100_crop_size,
        distort_train_images=FLAGS.cifar100_distort_train_images)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(
        task_spec, model=FLAGS.emnist_cr_model)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(
        task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(
        task_spec,
        vocab_size=FLAGS.so_nwp_vocab_size,
        num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
        sequence_length=FLAGS.so_nwp_sequence_length,
        max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
        num_validation_examples=FLAGS.so_nwp_num_validation_examples)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(
        task_spec,
        vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
        vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
        max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
        num_validation_examples=FLAGS.so_lr_num_validation_examples)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  _write_hparam_flags()

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
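# Hypothetical sketch of the shape of a mixin-check schedule, inferred only
# from the flag names above (client_mixin_check_scheme,
# client_mixin_epochs_per_round, client_mixin_check_start_round); the real
# fed_pa_schedule.create_mixin_check_fn may differ. The idea in FedPA is that
# a client's local SGD iterates are treated as approximate posterior samples
# ("mixed in") only after a burn-in period:
def _fixed_epochs_mixin_check(num_mixin_epochs, start_round):
  def mixedin_check(round_num, epoch):
    # Never mixed in before start_round; afterwards, require the client to
    # have completed num_mixin_epochs burn-in epochs.
    return round_num >= start_round and epoch >= num_mixin_epochs
  return mixedin_check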
def test_create_inv_sqrt_client_lr_schedule_from_flags(self):
  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 2.0, err=1e-5)
    self.assertNear(lr_schedule(1), 2.0, err=1e-5)
    self.assertNear(lr_schedule(10), 0.603022689155, err=1e-5)
    self.assertNear(lr_schedule(19), 0.603022689155, err=1e-5)
    self.assertNear(lr_schedule(20), 0.436435780472, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 2.0, err=1e-5)
    self.assertNear(lr_schedule(3), 1.0, err=1e-5)
    self.assertNear(lr_schedule(99), 0.2, err=1e-5)
    self.assertNear(lr_schedule(399), 0.1, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 2.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_sqrt_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 0.2, err=1e-5)
    self.assertNear(lr_schedule(1), 0.4, err=1e-5)
    self.assertNear(lr_schedule(10), 2.0, err=1e-5)
    self.assertNear(lr_schedule(13), 1.0, err=1e-5)
    self.assertNear(lr_schedule(109), 0.2, err=1e-5)
    self.assertNear(lr_schedule(409), 0.1, err=1e-5)
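# The values asserted above are consistent with an inverse-square-root decay
# of the form lr / sqrt(1 + decay_rate * t / decay_steps), where
# t / decay_steps is floored when staircase is set. A sketch inferred from the
# test values, not the actual implementation in optimizer_utils:
import math

def _inv_sqrt_decay(base_lr, decay_steps, decay_rate, staircase):
  def schedule(round_num):
    progress = round_num / decay_steps
    if staircase:
      progress = math.floor(progress)
    return base_lr / math.sqrt(1.0 + decay_rate * progress)
  return schedule

# _inv_sqrt_decay(2.0, 10, 10.0, True)(10) == 2 / sqrt(11) ≈ 0.603023, and the
# non-staircase variant gives 2 / sqrt(400) == 0.1 at round 399, as asserted.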
def test_create_inv_lin_client_lr_schedule_from_flags(self):
  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 5.0, err=1e-5)
    self.assertNear(lr_schedule(1), 5.0, err=1e-5)
    self.assertNear(lr_schedule(10), 0.454545454545, err=1e-5)
    self.assertNear(lr_schedule(19), 0.454545454545, err=1e-5)
    self.assertNear(lr_schedule(20), 0.238095238095, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 5.0, err=1e-5)
    self.assertNear(lr_schedule(1), 2.5, err=1e-5)
    self.assertNear(lr_schedule(9), 0.5, err=1e-5)
    self.assertNear(lr_schedule(19), 0.25, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 5.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'inv_lin_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 10.0,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 0.5, err=1e-5)
    self.assertNear(lr_schedule(1), 1.0, err=1e-5)
    self.assertNear(lr_schedule(10), 5.0, err=1e-5)
    self.assertNear(lr_schedule(11), 2.5, err=1e-5)
    self.assertNear(lr_schedule(19), 0.5, err=1e-5)
    self.assertNear(lr_schedule(29), 0.25, err=1e-5)
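# As above, these assertions match an inverse-linear decay,
# lr / (1 + decay_rate * t / decay_steps), with t / decay_steps floored when
# staircase is set. Sketch inferred from the test values (reuses `math` from
# the previous sketch):
def _inv_lin_decay(base_lr, decay_steps, decay_rate, staircase):
  def schedule(round_num):
    progress = round_num / decay_steps
    if staircase:
      progress = math.floor(progress)
    return base_lr / (1.0 + decay_rate * progress)
  return schedule

# _inv_lin_decay(5.0, 10, 10.0, True)(10) == 5 / 11 ≈ 0.454545, as asserted.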
def test_create_exp_decay_client_lr_schedule_from_flags(self):
  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): True,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 3.0, err=1e-5)
    self.assertNear(lr_schedule(3), 3.0, err=1e-5)
    self.assertNear(lr_schedule(10), 0.3, err=1e-5)
    self.assertNear(lr_schedule(19), 0.3, err=1e-5)
    self.assertNear(lr_schedule(20), 0.03, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 0,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 3.0, err=1e-5)
    self.assertNear(lr_schedule(1), 2.38298470417, err=1e-5)
    self.assertNear(lr_schedule(10), 0.3, err=1e-5)
    self.assertNear(lr_schedule(25), 0.00948683298, err=1e-5)

  with flag_sandbox({
      '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): 3.0,
      '{}_lr_schedule'.format(TEST_CLIENT_FLAG_PREFIX): 'exp_decay',
      '{}_lr_warmup_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_steps'.format(TEST_CLIENT_FLAG_PREFIX): 10,
      '{}_lr_decay_rate'.format(TEST_CLIENT_FLAG_PREFIX): 0.1,
      '{}_lr_staircase'.format(TEST_CLIENT_FLAG_PREFIX): False,
  }):
    lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        TEST_CLIENT_FLAG_PREFIX)
    self.assertNear(lr_schedule(0), 0.3, err=1e-5)
    self.assertNear(lr_schedule(1), 0.6, err=1e-5)
    self.assertNear(lr_schedule(10), 3.0, err=1e-5)
    self.assertNear(lr_schedule(11), 2.38298470417, err=1e-5)
    self.assertNear(lr_schedule(20), 0.3, err=1e-5)
    self.assertNear(lr_schedule(35), 0.00948683298, err=1e-5)
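# Similarly, the 'exp_decay' values match lr * decay_rate ** (t / decay_steps),
# the same formula as tf.keras.optimizers.schedules.ExponentialDecay. Sketch
# inferred from the test values (reuses `math` from above):
def _exp_decay(base_lr, decay_steps, decay_rate, staircase):
  def schedule(round_num):
    progress = round_num / decay_steps
    if staircase:
      progress = math.floor(progress)
    return base_lr * decay_rate ** progress
  return schedule

# _exp_decay(3.0, 10, 0.1, False)(1) == 3 * 0.1 ** 0.1 ≈ 2.382985, and the
# staircase variant drops from 3.0 to 0.3 exactly at round 10, as asserted.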
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`."""
    return fed_avg_schedule.build_fed_avg_process(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(
        task_spec,
        crop_size=FLAGS.cifar100_crop_size,
        distort_train_images=FLAGS.cifar100_distort_train_images)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(
        task_spec, model=FLAGS.emnist_cr_model)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(
        task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(
        task_spec,
        vocab_size=FLAGS.so_nwp_vocab_size,
        num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
        sequence_length=FLAGS.so_nwp_sequence_length,
        max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
        num_validation_examples=FLAGS.so_nwp_num_validation_examples)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(
        task_spec,
        vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
        vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
        max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
        num_validation_examples=FLAGS.so_lr_num_validation_examples)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  _write_hparam_flags()

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)