def main(argv):
  """Runs the federated experiment selected by the `--task` flag.

  Optimizers and learning rate schedules for the client and the server are
  created from flags, wrapped into an iterative process builder, and handed
  (together with the shared client flags) to the `run_federated` function of
  the selected task.

  Args:
    argv: Argument list; more than one entry is treated as a usage error.

  Raises:
    app.UsageError: If `argv` holds more than one entry.
    ValueError: If `FLAGS.task` does not match any of the handled tasks.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
      dataset_preprocess_comp: Optional[tff.Computation] = None,
  ) -> tff.templates.IterativeProcess:
    """Builds the federated averaging process for a TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.
      dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
        pipeline on the clients. The computation must take a sequence of
        values and return a sequence of values, or in TFF type shorthand
        `(U* -> V*)`. If `None`, no dataset preprocessing is applied.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    process_kwargs = dict(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn,
        dataset_preprocess_comp=dataset_preprocess_comp)
    return fed_avg_schedule.build_fed_avg_process(**process_kwargs)

  def flags_without_prefix(prefix):
    """Maps every flag named `prefix*` to its value, keyed sans the prefix."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in FLAGS
        if name.startswith(prefix))

  # Arguments every task's `run_federated` receives, in a fixed order.
  common_args = collections.OrderedDict()
  common_args['iterative_process_builder'] = iterative_process_builder
  common_args['assign_weights_fn'] = (
      fed_avg_schedule.ServerState.assign_weights_to_keras_model)
  common_args['client_epochs_per_round'] = FLAGS.client_epochs_per_round
  common_args['client_batch_size'] = FLAGS.client_batch_size
  common_args['clients_per_round'] = FLAGS.clients_per_round
  common_args['client_datasets_random_seed'] = FLAGS.client_datasets_random_seed

  # Resolve the runner and its task-specific keyword arguments first, then
  # launch it in a single place below.
  task = FLAGS.task
  if task == 'cifar100':
    run_task = federated_cifar100.run_federated
    task_args = {'crop_size': FLAGS.cifar100_crop_size}
  elif task == 'emnist_cr':
    run_task = federated_emnist.run_federated
    task_args = {'emnist_model': FLAGS.emnist_cr_model}
  elif task == 'emnist_ae':
    run_task = federated_emnist_ae.run_federated
    task_args = {}
  elif task == 'shakespeare':
    run_task = federated_shakespeare.run_federated
    task_args = {'sequence_length': FLAGS.shakespeare_sequence_length}
  elif task == 'stackoverflow_nwp':
    run_task = federated_stackoverflow.run_federated
    task_args = flags_without_prefix('so_nwp_')
  elif task == 'stackoverflow_lr':
    run_task = federated_stackoverflow_lr.run_federated
    task_args = flags_without_prefix('so_lr_')
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  run_task(**common_args, **task_args)
def main(argv):
  """Runs the adaptive-learning-rate federated experiment chosen by `--task`.

  Client and server optimizers are created from flags, and a
  reduce-LR-on-plateau callback is created for each side from the
  learning-rate/decay flags. These are wrapped into an iterative process
  builder which, together with the shared flags and the hyperparameter
  dictionary, is passed to the `run_federated` function of the selected task.

  Args:
    argv: Argument list; more than one entry is treated as a usage error.

  Raises:
    app.UsageError: If `argv` holds more than one entry.
    ValueError: If `FLAGS.task` does not match any of the handled tasks.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')

  # Client and server share the plateau-detection flags and differ only in
  # their initial learning rate and decay factor.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  # Task names handled by the dispatch below; kept local so the error message
  # does not depend on any module-level constant.
  supported_tasks = ('cifar100', 'emnist_cr', 'emnist_ae', 'shakespeare',
                     'stackoverflow_nwp', 'stackoverflow_lr')

  if FLAGS.task == 'cifar100':
    # The crop size is task-specific, so it is recorded among the
    # hyperparameters only when this task runs.
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args,
        emnist_model=FLAGS.emnist_cr_model,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)

  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_nwp':
    # Forward every `so_nwp_*` task flag with the prefix stripped from its
    # name.
    nwp_prefix = 'so_nwp_'
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith(nwp_prefix):
        so_nwp_flags[flag_name[len(nwp_prefix):]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)

  elif FLAGS.task == 'stackoverflow_lr':
    # Forward every `so_lr_*` task flag with the prefix stripped from its
    # name.
    lr_prefix = 'so_lr_'
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith(lr_prefix):
        so_lr_flags[flag_name[len(lr_prefix):]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)

  else:
    # Without this branch an unknown task would fall through and the program
    # would exit successfully without running anything.
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, supported_tasks))
def main(argv):
  """Runs the federated experiment selected by the `--task` flag.

  Optimizers and learning rate schedules for the client and the server are
  created from flags and wrapped into an iterative process builder. The
  builder, the weight-assignment function, the shared flag values and the
  hyperparameter dictionary are then passed to the `run_federated` function
  of the selected task.

  Args:
    argv: Argument list; more than one entry is treated as a usage error.

  Raises:
    app.UsageError: If `argv` holds more than one entry.
    ValueError: If `FLAGS.task` does not match any of the handled tasks.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Builds the federated averaging process for a TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    process_kwargs = dict(
        model_fn=model_fn,
        client_optimizer_fn=client_optimizer_fn,
        client_lr=client_lr_schedule,
        server_optimizer_fn=server_optimizer_fn,
        server_lr=server_lr_schedule,
        client_weight_fn=client_weight_fn)
    return fed_avg_schedule.build_fed_avg_process(**process_kwargs)

  def task_flags_without_prefix(prefix):
    """Maps each task flag named `prefix*` to its value, keyed sans prefix."""
    return collections.OrderedDict(
        (name[len(prefix):], FLAGS[name].value)
        for name in task_flags
        if name.startswith(prefix))

  assign_weights_fn = fed_avg_schedule.ServerState.assign_weights_to_keras_model

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder
  shared_args['assign_weights_fn'] = assign_weights_fn

  # Resolve the runner and its task-specific keyword arguments first, then
  # launch it in a single place below.
  task = FLAGS.task
  if task == 'cifar100':
    # The crop size is recorded among the hyperparameters for this task only.
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    run_task = federated_cifar100.run_federated
    task_args = {'crop_size': FLAGS.cifar100_crop_size}
  elif task == 'emnist_cr':
    run_task = federated_emnist.run_federated
    task_args = {'emnist_model': FLAGS.emnist_cr_model}
  elif task == 'emnist_ae':
    run_task = federated_emnist_ae.run_federated
    task_args = {}
  elif task == 'shakespeare':
    run_task = federated_shakespeare.run_federated
    task_args = {'sequence_length': FLAGS.shakespeare_sequence_length}
  elif task == 'stackoverflow_nwp':
    run_task = federated_stackoverflow.run_federated
    task_args = task_flags_without_prefix('so_nwp_')
  elif task == 'stackoverflow_lr':
    run_task = federated_stackoverflow_lr.run_federated
    task_args = task_flags_without_prefix('so_lr_')
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  run_task(**shared_args, **task_args, hparam_dict=hparam_dict)