def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
    """Creates a flag-configured `tff.templates.IterativeProcess`.

    Optimizers and learning rate schedules for both the client and the server
    are read from the `client`- and `server`-prefixed flags, then handed to
    `fed_avg_schedule.build_fed_avg_process`.

    Args:
      input_spec: A value convertible to a `tff.Type` describing the data fed
        to `tff.templates.IterativeProcess.next` during training. Usually the
        `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function returning an uncompiled
        `tf.keras.Model`.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`.
      metrics_builder: A no-arg function returning a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: Optional callable mapping the result of
        `tff.learning.Model.report_local_outputs` (for the model produced by
        `model_builder`) to a scalar client weight. When `None`, the weight
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
    # model.
    roles = ('client', 'server')
    # All optimizers are created before any schedule, client before server.
    optimizer_fns = {
        role: optimizer_utils.create_optimizer_fn_from_flags(role)
        for role in roles
    }
    lr_schedules = {
        role: optimizer_utils.create_lr_schedule_from_flags(role)
        for role in roles
    }

    def tff_model_fn() -> tff.learning.Model:
        # Fresh Keras objects are built on every call.
        keras_model = model_builder()
        loss = loss_builder()
        metrics = metrics_builder()
        return tff.learning.from_keras_model(
            keras_model=keras_model,
            input_spec=input_spec,
            loss=loss,
            metrics=metrics)

    return fed_avg_schedule.build_fed_avg_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=optimizer_fns['client'],
        client_lr=lr_schedules['client'],
        server_optimizer_fn=optimizer_fns['server'],
        server_lr=lr_schedules['server'],
        client_weight_fn=client_weight_fn)
def test_create_optimizer_fn_with_no_learning_rate(self):
    # An optimizer flag with its learning-rate flag set to None must be
    # rejected with a ValueError mentioning the learning rate.
    flag_overrides = {
        '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd',
        '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX): None,
    }
    with flag_sandbox(flag_overrides), self.assertRaisesRegex(
            ValueError, 'Learning rate'):
        optimizer_utils.create_optimizer_fn_from_flags(TEST_CLIENT_FLAG_PREFIX)
def main(argv):
    """Runs federated training with flag-configured optimizers.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
    server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Builds a federated averaging process for a TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function taking the output of
            `model.report_local_outputs` and returning a tensor giving the
            weight in the federated average of model deltas. If omitted, the
            default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            client_optimizer_fn=client_opt_fn,
            server_optimizer_fn=server_opt_fn,
            client_weight_fn=client_weight_fn,
            use_experimental_simulation_loop=True)

    # GLD23K is used unless the flag explicitly asks for 'gld160k'.
    dataset_type = (
        dataset.DatasetType.GLD160K
        if FLAGS.dataset_type == 'gld160k'
        else dataset.DatasetType.GLD23K)

    federated_main.run_federated(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        max_elements_per_user=FLAGS.max_elements_per_user,
        image_size=FLAGS.image_size,
        num_groups=FLAGS.num_groups,
        total_rounds=FLAGS.total_rounds,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        hparam_dict=get_hparam_flags())
def test_create_optimizer_fn_from_flags_flags_set_not_for_optimizer(self):
    # Selecting SGD while explicitly setting an Adam-only flag must raise a
    # ValueError that names both the optimizer and the offending flag.
    bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
    with flag_sandbox(
            {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        # Set an Adam flag that isn't used in SGD.
        # We need to use `_parse_args` because that is the only way FLAGS is
        # notified that a non-default value is being used.
        FLAGS._parse_args(
            args=['--{}=0.5'.format(bad_adam_flag)], known_only=True)
        try:
            with self.assertRaisesRegex(
                    ValueError,
                    r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'
            ):
                optimizer_utils.create_optimizer_fn_from_flags(
                    TEST_CLIENT_FLAG_PREFIX)
        finally:
            # Always undo the parse, even when the assertion above fails, so
            # the non-default Adam flag cannot leak into other tests.
            FLAGS[bad_adam_flag].unparse()
def main(argv):
    """Runs centralized training for the task selected by `--task`.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` is anything other than 'cifar10'.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    hparams_dict = collections.OrderedDict(
        (name, FLAGS[name].value) for name in hparam_flags)

    # Keyword arguments passed to the task's `run_centralized`.
    common_args = collections.OrderedDict()
    common_args['optimizer'] = optimizer
    common_args['experiment_name'] = FLAGS.experiment_name
    common_args['root_output_dir'] = FLAGS.root_output_dir
    common_args['num_epochs'] = FLAGS.num_epochs
    common_args['batch_size'] = FLAGS.batch_size
    common_args['decay_epochs'] = FLAGS.decay_epochs
    common_args['lr_decay'] = FLAGS.lr_decay
    common_args['hparams_dict'] = hparams_dict

    if FLAGS.task != 'cifar10':
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    centralized_cifar10.run_centralized(
        **common_args, crop_size=FLAGS.cifar100_crop_size)
def main(argv):
    """Runs centralized training for the task selected by `--task`.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    hparams_dict = collections.OrderedDict(
        (name, FLAGS[name].value) for name in hparam_flags)

    # Keyword arguments passed to every task's `run_centralized`.
    common_args = collections.OrderedDict()
    common_args['optimizer'] = optimizer
    common_args['experiment_name'] = FLAGS.experiment_name
    common_args['root_output_dir'] = FLAGS.root_output_dir
    common_args['num_epochs'] = FLAGS.num_epochs
    common_args['batch_size'] = FLAGS.batch_size
    common_args['decay_epochs'] = FLAGS.decay_epochs
    common_args['lr_decay'] = FLAGS.lr_decay
    common_args['hparams_dict'] = hparams_dict

    def _flags_with_prefix(prefix):
        # Every flag whose name starts with `prefix`, keyed by the remainder
        # of the name, so it can be splatted as keyword arguments.
        return collections.OrderedDict(
            (flag_name[len(prefix):], FLAGS[flag_name].value)
            for flag_name in FLAGS
            if flag_name.startswith(prefix))

    task = FLAGS.task
    if task == 'cifar100':
        centralized_cifar100.run_centralized(
            **common_args, crop_size=FLAGS.cifar100_crop_size)
    elif task == 'emnist_cr':
        centralized_emnist.run_centralized(
            **common_args, emnist_model=FLAGS.emnist_cr_model)
    elif task == 'emnist_ae':
        centralized_emnist_ae.run_centralized(**common_args)
    elif task == 'shakespeare':
        centralized_shakespeare.run_centralized(
            **common_args, sequence_length=FLAGS.shakespeare_sequence_length)
    elif task == 'stackoverflow_nwp':
        centralized_stackoverflow.run_centralized(
            **common_args, **_flags_with_prefix('so_nwp_'))
    elif task == 'stackoverflow_lr':
        centralized_stackoverflow_lr.run_centralized(
            **common_args, **_flags_with_prefix('so_lr_'))
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
def test_create_server_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
    # The optimizer factory built from server flags must produce
    # `optimizer_cls` instances, using the flag learning rate by default and
    # an explicitly passed learning rate when one is given.
    flag_learning_rate = 100.0
    explicit_learning_rate = 11.0
    flag_overrides = {
        '{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX): optimizer_name,
        '{}_learning_rate'.format(TEST_SERVER_FLAG_PREFIX): flag_learning_rate,
    }
    with flag_sandbox(flag_overrides):
        optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
            TEST_SERVER_FLAG_PREFIX)

        default_optimizer = optimizer_fn()
        self.assertIsInstance(default_optimizer, optimizer_cls)
        self.assertEqual(
            default_optimizer.get_config()['learning_rate'],
            flag_learning_rate)

        overridden_optimizer = optimizer_fn(explicit_learning_rate)
        self.assertIsInstance(overridden_optimizer, optimizer_cls)
        self.assertEqual(
            overridden_optimizer.get_config()['learning_rate'],
            explicit_learning_rate)
def main(argv):
    """Runs centralized training configured entirely from flags.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hyperparameter values, filtered through `remove_unused_flags` for the
    # 'centralized' prefix.
    all_hparams = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       all_hparams)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    centralized_main.run_centralized(
        optimizer,
        FLAGS.num_epochs,
        FLAGS.batch_size,
        vocab_size=FLAGS.vocab_size,
        d_embed=FLAGS.d_embed,
        d_model=FLAGS.d_model,
        d_hidden=FLAGS.d_hidden,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        dropout=FLAGS.dropout,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        max_batches=FLAGS.max_batches,
        hparams_dict=hparams_dict)
def main(argv):
    """Runs centralized training on the dataset selected by flags.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hyperparameter values, filtered through `remove_unused_flags` for the
    # 'centralized' prefix.
    all_hparams = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       all_hparams)

    # GLD23K is used unless the flag explicitly asks for 'gld160k'.
    dataset_type = (
        dataset.DatasetType.GLD160K
        if FLAGS.dataset_type == 'gld160k'
        else dataset.DatasetType.GLD23K)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()

    centralized_main.run_centralized(
        optimizer=optimizer,
        image_size=FLAGS.image_size,
        num_epochs=FLAGS.num_epochs,
        batch_size=FLAGS.batch_size,
        num_groups=FLAGS.num_groups,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        hparams_dict=hparams_dict)
def main(argv):
    """Configures the task selected by `--task` and runs the training loop.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
    server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Builds the iterative process for a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        # These two tasks weight clients by their reported 'num_tokens';
        # every other task passes no client weight function.
        client_weight_fn = None
        if FLAGS.task in ('shakespeare', 'stackoverflow_nwp'):

            def client_weight_fn(local_outputs):
                return tf.cast(
                    tf.squeeze(local_outputs['num_tokens']), tf.float32)

        return fed_avg_local_adaptivity.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_opt_fn,
            client_lr=FLAGS.client_learning_rate,
            server_optimizer_fn=server_opt_fn,
            server_lr=FLAGS.server_learning_rate,
            client_weight_fn=client_weight_fn,
            correction=FLAGS.correction_type)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    task = FLAGS.task
    if task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
def main(argv):
    """Configures the task selected by `--task` and runs the training loop.

    Optimizers and learning rate schedules for the client and the server are
    built from the `client`- and `server`-prefixed flags.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    roles = ('client', 'server')
    # All optimizers are created before any schedule, client before server.
    optimizer_fns = {
        role: optimizer_utils.create_optimizer_fn_from_flags(role)
        for role in roles
    }
    lr_schedules = {
        role: optimizer_utils.create_lr_schedule_from_flags(role)
        for role in roles
    }

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Builds the iterative process for a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function taking the output of
            `model.report_local_outputs` and returning a tensor giving the
            weight in the federated average of model deltas. If omitted, the
            default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=optimizer_fns['client'],
            client_lr=lr_schedules['client'],
            server_optimizer_fn=optimizer_fns['server'],
            server_lr=lr_schedules['server'],
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    task = FLAGS.task
    if task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        rounds_per_profile=FLAGS.rounds_per_profile)
def main(argv):
    """Runs adaptive federated averaging on the task selected by `--task`.

    Client and server optimizers come from the `client`- and
    `server`-prefixed flags; each side also gets a reduce-on-plateau
    learning rate callback built from flags.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)

    def _task_flags_with_prefix(prefix):
        # Task flags whose name starts with `prefix`, keyed by the remainder
        # of the name, so they can be splatted as keyword arguments.
        return collections.OrderedDict(
            (flag_name[len(prefix):], FLAGS[flag_name].value)
            for flag_name in task_flags
            if flag_name.startswith(prefix))

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder

    if FLAGS.task == 'cifar100':
        # The crop size is recorded with the hyperparameters for this task.
        hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
        federated_cifar100.run_federated(
            **shared_args,
            crop_size=FLAGS.cifar100_crop_size,
            hparam_dict=hparam_dict)
    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(
            **shared_args,
            model=FLAGS.emnist_cr_model,
            hparam_dict=hparam_dict)
    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(
            **shared_args, hparam_dict=hparam_dict)
    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **shared_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            hparam_dict=hparam_dict)
    elif FLAGS.task == 'stackoverflow_nwp':
        federated_stackoverflow.run_federated(
            **shared_args,
            **_task_flags_with_prefix('so_nwp_'),
            hparam_dict=hparam_dict)
    elif FLAGS.task == 'stackoverflow_lr':
        federated_stackoverflow_lr.run_federated(
            **shared_args,
            **_task_flags_with_prefix('so_lr_'),
            hparam_dict=hparam_dict)
    else:
        # Previously an unrecognized task fell through and the program exited
        # without training anything; fail loudly instead.
        raise ValueError(
            '--task flag {} is not supported.'.format(FLAGS.task))
def main(argv):
    """Configures the task selected by `--task` and runs the training loop.

    Client and server optimizers come from the `client`- and
    `server`-prefixed flags; each side also gets a reduce-on-plateau
    learning rate callback built from flags.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
    server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

    # Plateau-detection settings shared by the client and server callbacks.
    plateau_kwargs = dict(
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        **plateau_kwargs)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        **plateau_kwargs)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Builds the iterative process for a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_opt_fn,
            server_optimizer_fn=server_opt_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    task = FLAGS.task
    if task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec, crop_size=FLAGS.cifar100_crop_size)
    elif task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        rounds_per_profile=FLAGS.rounds_per_profile)
def main(argv):
    """Runs federated training using the algorithm chosen by `--schedule`.

    `--schedule` selects between the 'importance' variant, the 'loss'
    variant and the default `fed_avg` module. `--task` selects which
    `run_federated` function is invoked.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--task` does not match any handled task.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    client_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
    server_opt_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
    client_lr = optimizer_utils.create_lr_schedule_from_flags('client')
    server_lr = optimizer_utils.create_lr_schedule_from_flags('server')

    # The builder variant is chosen once, here, from the schedule flag.
    schedule = FLAGS.schedule
    if schedule == 'importance':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Builds the importance-sampling iterative process.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Accepted so all builders share a signature;
                this variant does not forward it.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            sampling_factory = (
                importance_aggregation_factory.ImportanceSamplingFactory(
                    FLAGS.clients_per_round))
            model_weights_type = (
                importance_aggregation_factory.weights_type_from_model_fn(
                    model_fn))
            aggregation_process = sampling_factory.create(
                value_type=model_weights_type,
                weight_type=tff.TensorType(tf.float32))
            return importance_schedule.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_opt_fn,
                client_lr=client_lr,
                server_optimizer_fn=server_opt_fn,
                server_lr=server_lr,
                aggregation_process=aggregation_process)

    elif schedule == 'loss':

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Builds the iterative process from the `fed_loss` module.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function taking the output of
                `model.report_local_outputs` and returning a tensor giving
                the weight in the federated average of model deltas. If
                omitted, the default is the total number of examples
                processed on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_loss.build_fed_avg_process(
                total_clients=FLAGS.loss_pool_size,
                effective_num_clients=FLAGS.clients_per_round,
                model_fn=model_fn,
                client_optimizer_fn=client_opt_fn,
                client_lr=client_lr,
                server_optimizer_fn=server_opt_fn,
                server_lr=server_lr,
                client_weight_fn=client_weight_fn,
                aggregation_process=None)

    else:

        def iterative_process_builder(
            model_fn: Callable[[], tff.learning.Model],
            client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        ) -> tff.templates.IterativeProcess:
            """Builds the iterative process from the `fed_avg` module.

            Args:
              model_fn: A no-arg function returning a `tff.learning.Model`.
              client_weight_fn: Optional function taking the output of
                `model.report_local_outputs` and returning a tensor giving
                the weight in the federated average of model deltas. If
                omitted, the default is the total number of examples
                processed on device.

            Returns:
              A `tff.templates.IterativeProcess`.
            """
            return fed_avg.build_fed_avg_process(
                model_fn=model_fn,
                client_optimizer_fn=client_opt_fn,
                client_lr=client_lr,
                server_optimizer_fn=server_opt_fn,
                server_lr=server_lr,
                client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    task = FLAGS.task
    if task == 'cifar100':
        run_federated_fn = federated_cifar100.run_federated
    elif task == 'emnist_cr':
        run_federated_fn = federated_emnist.run_federated
    elif task == 'emnist_ae':
        run_federated_fn = federated_emnist_ae.run_federated
    elif task == 'shakespeare':
        run_federated_fn = federated_shakespeare.run_federated
    elif task == 'stackoverflow_nwp':
        run_federated_fn = federated_stackoverflow.run_federated
    elif task == 'stackoverflow_lr':
        run_federated_fn = federated_stackoverflow_lr.run_federated
    elif task == 'synthetic':
        run_federated_fn = federated_synthetic.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(
        **shared_args,
        **task_args,
        beta=FLAGS.beta,
        hparam_dict=hparam_dict,
        schedule=FLAGS.schedule)
def main(argv):
    """Trains an EMNIST model with federated averaging, optionally with DP.

    The model architecture, client weighting and differential-privacy
    settings are all read from flags. When `--noise_multiplier` is set, a DP
    aggregation process is built and uniform client weighting is required.

    Args:
      argv: Command-line arguments left over after flag parsing; only the
        program name is allowed.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `--model` is neither 'cnn' nor '2nn', or if a noise
        multiplier is set without `--uniform_weighting`.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, '
            'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            # Every client counts equally, whatever it reports.
            del local_outputs
            return 1.0

    else:
        client_weight_fn = None  # Defaults to the number of examples per client.

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=(
                FLAGS.clipped_count_budget_allocation),
            expected_clients_per_round=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    # `Model.summary()` prints to stdout and returns None, so logging its
    # return value only records the text 'None'. Capture the summary through
    # `print_fn` and log the text itself. Extra keyword arguments are
    # absorbed in case the installed Keras passes any to `print_fn`.
    summary_lines = []
    model_builder().summary(
        print_fn=lambda line, **unused_kwargs: summary_lines.append(line))
    logging.info('%s', '\n'.join(summary_lines))

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate_fn,
        hparam_dict=hparam_dict,
        **training_loop_dict)
def test_create_optimizer_fn_from_flags_invalid_optimizer(self):
    # An optimizer flag holding an unknown name must be rejected with a
    # ValueError.
    optimizer_flag = FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)]
    original_value = optimizer_flag.value
    optimizer_flag.value = 'foo'
    try:
        with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
            optimizer_utils.create_optimizer_fn_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
    finally:
        # FLAGS is process-global: put the original value back so the bogus
        # optimizer name cannot leak into tests that run afterwards.
        optimizer_flag.value = original_value
def main(argv):
    """Runs a federated reconstruction (or fine-tuning) experiment from flags.

    Builds the optimizers named by flags, defines the builders for the
    training process and the evaluation computation, then hands both to the
    `run_federated` function of the task selected by `FLAGS.task`.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If the fine-tune task is selected without the flags it
        requires, or if `FLAGS.task` is not a supported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    finetune_task = 'stackoverflow_nwp_finetune'

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    # Only one of the two extra optimizers is created, depending on the task.
    if FLAGS.task != finetune_task:
        reconstruction_optimizer_fn = (
            optimizer_utils.create_optimizer_fn_from_flags('reconstruction'))
    else:
        if not FLAGS.global_variables_only:
            raise ValueError('`FLAGS.global_variables_only` must be True for '
                             'a `stackoverflow_nwp_finetune` task.')
        if not FLAGS.client_epochs_per_round:
            raise ValueError('`FLAGS.client_epochs_per_round` must be set for '
                             'a `stackoverflow_nwp_finetune` task.')
        finetune_optimizer_fn = (
            optimizer_utils.create_optimizer_fn_from_flags('finetune'))

    def iterative_process_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], List[tf.keras.losses.Loss]],
        metrics_fn: Optional[Callable[[],
                                      List[tf.keras.metrics.Metric]]] = None,
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils.
        build_dataset_split_fn,
    ) -> tff.templates.IterativeProcess:
        """Builds the training iterative process for the selected task.

        For the `stackoverflow_nwp_finetune` task the model may only hold
        global variables and `dataset_split_fn_builder` is not used; every
        client trains on its whole dataset repeated
        `FLAGS.client_epochs_per_round` times, with an empty reconstruction
        split. For every other task the process is the federated
        reconstruction algorithm configured from the reconstruction flags.

        Args:
          model_fn: A no-arg function returning a
            `reconstruction_model.ReconstructionModel`.
          loss_fn: A no-arg function returning a list of
            `tf.keras.losses.Loss`.
          metrics_fn: A no-arg function returning a list of
            `tf.keras.metrics.Metric`.
          client_weight_fn: Optional function mapping the local model's
            output to the client's weight in the federated average. Replaced
            by uniform weighting when `FLAGS.dp_noise_multiplier` is set.
          dataset_split_fn_builder: `DatasetSplitFn` builder producing the
            reconstruction / post-reconstruction split. Ignored for the
            fine-tune task.

        Raises:
          ValueError: If `model_fn` returns a model with local variables for
            the `stackoverflow_nwp_finetune` task.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.dp_noise_multiplier is None:
            aggregation_factory = None
            client_weighting = client_weight_fn
        else:
            aggregation_factory = tff.learning.dp_aggregator(
                noise_multiplier=FLAGS.dp_noise_multiplier,
                clients_per_round=float(FLAGS.clients_per_round),
                zeroing=FLAGS.dp_zeroing)
            # DP is only implemented for uniform weighting.
            client_weighting = lambda _unused_outputs: 1.0

        # Arguments shared by both flavors of the process.
        common_kwargs = dict(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=lambda: server_optimizer_fn(
                FLAGS.server_learning_rate),
            client_optimizer_fn=lambda: client_optimizer_fn(
                FLAGS.client_learning_rate),
            client_weight_fn=client_weighting,
            aggregation_factory=aggregation_factory)

        if FLAGS.task == finetune_task:
            if not reconstruction_utils.has_only_global_variables(model_fn()):
                raise ValueError(
                    '`model_fn` should return a model with only global variables. '
                )

            def fake_dataset_split_fn(
                client_dataset: tf.data.Dataset, round_num: tf.Tensor
            ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
                # Empty reconstruction split; all data goes to training.
                del round_num
                empty_split = client_dataset.repeat(0)
                training_split = client_dataset.repeat(
                    FLAGS.client_epochs_per_round)
                return empty_split, training_split

            return training_process.build_federated_reconstruction_process(
                dataset_split_fn=fake_dataset_split_fn, **common_kwargs)

        return training_process.build_federated_reconstruction_process(
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn_builder(
                recon_epochs_max=FLAGS.recon_epochs_max,
                recon_epochs_constant=FLAGS.recon_epochs_constant,
                recon_steps_max=FLAGS.recon_steps_max,
                post_recon_epochs=FLAGS.post_recon_epochs,
                post_recon_steps_max=FLAGS.post_recon_steps_max,
                split_dataset=FLAGS.split_dataset),
            evaluate_reconstruction=FLAGS.evaluate_reconstruction,
            jointly_train_variables=FLAGS.jointly_train_variables,
            **common_kwargs)

    def evaluation_computation_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], tf.losses.Loss],
        metrics_fn: Callable[[], List[tf.metrics.Metric]],
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = reconstruction_utils.
        build_dataset_split_fn,
    ) -> tff.Computation:
        """Builds the federated evaluation computation for the selected task.

        The fine-tune task uses
        `federated_evaluation.build_federated_finetune_evaluation`; all other
        tasks use
        `evaluation_computation.build_federated_reconstruction_evaluation`.

        Args:
          model_fn: A no-arg function that returns a `ReconstructionModel`,
            built from scratch on each call. It must have only global
            variables for the `stackoverflow_nwp_finetune` task.
          loss_fn: A no-arg function returning the `tf.keras.losses.Loss`
            used to evaluate the model.
          metrics_fn: A no-arg function returning a list of
            `tf.keras.metrics.Metric`s used to evaluate the model.
          dataset_split_fn_builder: `DatasetSplitFn` builder. The first split
            is the reconstruction (or fine-tuning) set, the second the
            evaluation set.

        Returns:
          A `tff.Computation` for federated evaluation.
        """
        # The first dataset of the split fine-tunes global variables for the
        # fine-tune task and reconstructs local variables for other tasks.
        eval_split_fn = dataset_split_fn_builder(
            recon_epochs_max=FLAGS.recon_epochs_max,
            recon_epochs_constant=FLAGS.recon_epochs_constant,
            recon_steps_max=FLAGS.recon_steps_max,
            post_recon_epochs=FLAGS.post_recon_epochs,
            post_recon_steps_max=FLAGS.post_recon_steps_max,
            # Getting meaningful evaluation metrics requires splitting the
            # data.
            split_dataset=True)

        if FLAGS.task != finetune_task:
            return evaluation_computation.build_federated_reconstruction_evaluation(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                reconstruction_optimizer_fn=functools.partial(
                    reconstruction_optimizer_fn,
                    FLAGS.reconstruction_learning_rate),
                dataset_split_fn=eval_split_fn)

        return federated_evaluation.build_federated_finetune_evaluation(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            finetune_optimizer_fn=functools.partial(
                finetune_optimizer_fn, FLAGS.finetune_learning_rate),
            dataset_split_fn=eval_split_fn)

    # Shared args, useful to support more tasks.
    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    shared_args['evaluation_computation_builder'] = (
        evaluation_computation_builder)

    task_args = _get_task_args()
    _write_hparam_flags()

    if FLAGS.task in ['stackoverflow_nwp', 'stackoverflow_nwp_finetune']:
        run_federated_fn = federated_stackoverflow.run_federated
    elif FLAGS.task == 'movielens_mf':
        run_federated_fn = federated_movielens.run_federated
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    run_federated_fn(**shared_args, **task_args)
def main(argv):
    """Runs a FedPA training experiment configured entirely from flags.

    Creates the client/server optimizers and learning rate schedules, the
    FedPA mix-in check and update-delta functions, wraps them in an iterative
    process builder, configures the task selected by `FLAGS.task`, and runs
    the training loop.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `FLAGS.task` is not one of the handled tasks.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    mixin_check_fn = fed_pa_schedule.create_mixin_check_fn(
        name=FLAGS.client_mixin_check_scheme,
        num_mixin_epochs=FLAGS.client_mixin_epochs_per_round,
        start_round=FLAGS.client_mixin_check_start_round)
    update_delta_fn = fed_pa_schedule.create_update_delta_fn(
        name=FLAGS.client_update_delta_scheme,
        rho=FLAGS.client_shrinkage_rho)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Builds the FedPA iterative process for `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_pa_schedule.build_fed_pa_process(
            model_fn=model_fn,
            client_update_epochs=FLAGS.client_epochs_per_round,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_mixedin_schedule_fn=mixin_check_fn,
            client_update_delta_fn=update_delta_fn,
            mask_zeros_in_client_updates=FLAGS.mask_zeros_in_client_updates)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        # Since the number of epochs each client makes every round is handled
        # by the logic in client update functions, here we set it to 1.
        client_epochs_per_round=1,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    task = FLAGS.task
    if task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
def main(argv):
    """Trains an EMNIST classifier with FedAvg, optionally differentially private.

    All configuration is read from `FLAGS`. When `FLAGS.noise_multiplier` is
    set, a Gaussian DP aggregation factory is used: fixed clipping when
    `FLAGS.adaptive_clip_learning_rate` is falsy, adaptive clipping otherwise.
    Without DP, a weighted or unweighted mean factory is selected by
    `FLAGS.uniform_weighting`. Hyperparameter flags are written to
    `<root_output_dir>/results/<experiment_name>/hparams.csv` before training.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `FLAGS.model` is not handled, or if the DP flags are
        inconsistent (non-uniform weighting, non-positive noise multiplier,
        missing or non-positive clip, non-positive adaptive learning rate).
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, _ = emnist_dataset.get_federated_datasets(
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        only_digits=False)
    _, emnist_test = emnist_dataset.get_centralized_datasets()

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:
        client_weighting = tff.learning.ClientWeighting.UNIFORM
    else:
        client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.')
        if FLAGS.noise_multiplier <= 0:
            raise ValueError(
                'noise_multiplier must be positive if DP is enabled.')
        if FLAGS.clip is None or FLAGS.clip <= 0:
            raise ValueError('clip must be positive if DP is enabled.')

        if not FLAGS.adaptive_clip_learning_rate:
            # Fixed clipping norm.
            aggregation_factory = (
                tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
                    noise_multiplier=FLAGS.noise_multiplier,
                    clients_per_round=FLAGS.clients_per_round,
                    clip=FLAGS.clip))
        else:
            if FLAGS.adaptive_clip_learning_rate <= 0:
                raise ValueError(
                    'adaptive_clip_learning_rate must be positive if '
                    'adaptive clipping is enabled.')
            aggregation_factory = (
                tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
                    noise_multiplier=FLAGS.noise_multiplier,
                    clients_per_round=FLAGS.clients_per_round,
                    initial_l2_norm_clip=FLAGS.clip,
                    target_unclipped_quantile=FLAGS.target_unclipped_quantile,
                    learning_rate=FLAGS.adaptive_clip_learning_rate))
    else:
        if FLAGS.uniform_weighting:
            aggregation_factory = tff.aggregators.UnweightedMeanFactory()
        else:
            aggregation_factory = tff.aggregators.MeanFactory()

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    # The training loop passes the round number too; evaluation ignores it.
    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    logging.info('Training model:')
    # `Model.summary()` prints its table and returns None, so logging its
    # return value would only log "None". Route each summary line through the
    # logger instead.
    model_builder().summary(print_fn=logging.info)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=validation_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
        rounds_per_profile=FLAGS.rounds_per_profile)
def main(argv):
    """Runs FedAvg with learning rate schedules on the task named by flags.

    Creates the client/server optimizers and learning rate schedules, wraps
    them in an iterative process builder, and forwards that builder together
    with the shared, task-specific and hyperparameter flag values to the
    selected task's `run_federated` function.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `FLAGS.task` is not one of the handled tasks.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Builds the scheduled FedAvg iterative process for `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not
            provided, the default is the total number of examples processed
            on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    task_args = _get_task_args()
    hparam_dict = _get_hparam_flags()

    # Task name -> entry point that runs the federated experiment.
    runners_by_task = {
        'cifar100': federated_cifar100.run_federated,
        'emnist_cr': federated_emnist.run_federated,
        'emnist_ae': federated_emnist_ae.run_federated,
        'shakespeare': federated_shakespeare.run_federated,
        'stackoverflow_nwp': federated_stackoverflow.run_federated,
        'stackoverflow_lr': federated_stackoverflow_lr.run_federated,
    }
    if FLAGS.task not in runners_by_task:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    run_federated_fn = runners_by_task[FLAGS.task]

    run_federated_fn(**shared_args, **task_args, hparam_dict=hparam_dict)
def main(argv):
    """Runs FedAvg with optional clipping or differential privacy from flags.

    Defines an iterative process builder that selects the client weighting
    and the model-update aggregation factory from flags, configures the task
    selected by `FLAGS.task`, and runs the training loop.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If `FLAGS.task` is not one of the handled tasks. The
        builder additionally raises `ValueError` for inconsistent clipping
        or DP flags.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Builds the federated averaging process for `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Raises:
          ValueError: If clipping or DP flags are inconsistent (non-positive
            clip, noise multiplier or adaptive learning rate; DP without
            uniform weighting or without a clip).

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        logging.info('Trainable weights:')
        for variable in model_fn().weights.trainable:
            logging.info('name: %s shape: %s', variable.name, variable.shape)

        if FLAGS.uniform_weighting:
            client_weighting = tff.learning.ClientWeighting.UNIFORM
        elif FLAGS.task in ('shakespeare', 'stackoverflow_nwp'):

            def client_weighting(local_outputs):
                # Weight these language tasks by the reported token count.
                num_tokens = tf.squeeze(local_outputs['num_tokens'])
                return tf.cast(num_tokens, tf.float32)
        else:
            client_weighting = None

        if FLAGS.noise_multiplier is not None:
            # Differentially private aggregation.
            if not FLAGS.uniform_weighting:
                raise ValueError(
                    'Differential privacy is only implemented for uniform weighting.')
            if FLAGS.noise_multiplier <= 0:
                raise ValueError(
                    'noise_multiplier must be positive if DP is enabled.')
            if FLAGS.clip is None or FLAGS.clip <= 0:
                raise ValueError('clip must be positive if DP is enabled.')

            if FLAGS.adaptive_clip_learning_rate is not None:
                if FLAGS.adaptive_clip_learning_rate <= 0:
                    raise ValueError(
                        'adaptive_clip_learning_rate must be positive if '
                        'adaptive clipping is enabled.')
                aggregation_factory = (
                    tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
                        noise_multiplier=FLAGS.noise_multiplier,
                        clients_per_round=FLAGS.clients_per_round,
                        initial_l2_norm_clip=FLAGS.clip,
                        target_unclipped_quantile=(
                            FLAGS.target_unclipped_quantile),
                        learning_rate=FLAGS.adaptive_clip_learning_rate))
            else:
                aggregation_factory = (
                    tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
                        noise_multiplier=FLAGS.noise_multiplier,
                        clients_per_round=FLAGS.clients_per_round,
                        clip=FLAGS.clip))
        else:
            # Plain mean, optionally wrapped in fixed or adaptive clipping.
            if FLAGS.uniform_weighting:
                aggregation_factory = tff.aggregators.UnweightedMeanFactory()
            else:
                aggregation_factory = tff.aggregators.MeanFactory()

            if FLAGS.clip is not None:
                if FLAGS.clip <= 0:
                    raise ValueError(
                        'clip must be positive if clipping is enabled.')
                if FLAGS.adaptive_clip_learning_rate is not None:
                    if FLAGS.adaptive_clip_learning_rate <= 0:
                        raise ValueError(
                            'adaptive_clip_learning_rate must be positive if '
                            'adaptive clipping is enabled.')
                    clip = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
                        initial_estimate=FLAGS.clip,
                        target_quantile=FLAGS.target_unclipped_quantile,
                        learning_rate=FLAGS.adaptive_clip_learning_rate)
                else:
                    clip = FLAGS.clip
                aggregation_factory = tff.aggregators.clipping_factory(
                    clip, aggregation_factory)

        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weighting=client_weighting,
            client_optimizer_fn=client_optimizer_fn,
            model_update_aggregation_factory=aggregation_factory)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # Task name -> function that configures training for that task.
    configure_by_task = {
        'cifar100': federated_cifar100.configure_training,
        'emnist_cr': federated_emnist.configure_training,
        'emnist_ae': federated_emnist_ae.configure_training,
        'shakespeare': federated_shakespeare.configure_training,
        'stackoverflow_nwp': federated_stackoverflow.configure_training,
        'stackoverflow_lr': federated_stackoverflow_lr.configure_training,
    }
    if FLAGS.task not in configure_by_task:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    runner_spec = configure_by_task[FLAGS.task](task_spec)

    _write_hparam_flags()

    training_loop.run(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
def main(argv):
    """Trains a Stack Overflow next-word model with FedAvg, optionally with DP.

    All configuration is read from `FLAGS`. When `FLAGS.noise_multiplier` is
    set, a Gaussian DP aggregation factory is used: fixed clipping when
    `FLAGS.adaptive_clip_learning_rate` is falsy, adaptive clipping otherwise.
    Validation runs on the validation split; the final test runs on the
    validation and test splits concatenated. Hyperparameter flags are written
    to `<root_output_dir>/results/<experiment_name>/hparams.csv` before
    training.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If the DP flags are inconsistent (non-uniform weighting,
        non-positive noise multiplier, missing or non-positive clip,
        non-positive adaptive learning rate).
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_with_oov', masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov',
                masked_tokens=[pad_token] + oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
        vocab_size=FLAGS.vocab_size,
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_sequence_length=FLAGS.sequence_length,
        max_elements_per_train_client=FLAGS.max_elements_per_user)
    _, validation_dataset, test_dataset = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size=FLAGS.vocab_size,
            max_sequence_length=FLAGS.sequence_length,
            num_validation_examples=FLAGS.num_validation_examples))

    if FLAGS.uniform_weighting:
        client_weighting = tff.learning.ClientWeighting.UNIFORM
    else:
        client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )
        if FLAGS.noise_multiplier <= 0:
            raise ValueError(
                'noise_multiplier must be positive if DP is enabled.')
        if FLAGS.clip is None or FLAGS.clip <= 0:
            raise ValueError('clip must be positive if DP is enabled.')

        if not FLAGS.adaptive_clip_learning_rate:
            # Fixed clipping norm.
            aggregation_factory = (
                tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
                    noise_multiplier=FLAGS.noise_multiplier,
                    clients_per_round=FLAGS.clients_per_round,
                    clip=FLAGS.clip))
        else:
            if FLAGS.adaptive_clip_learning_rate <= 0:
                raise ValueError(
                    'adaptive_clip_learning_rate must be positive if '
                    'adaptive clipping is enabled.')
            aggregation_factory = (
                tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
                    noise_multiplier=FLAGS.noise_multiplier,
                    clients_per_round=FLAGS.clients_per_round,
                    initial_l2_norm_clip=FLAGS.clip,
                    target_unclipped_quantile=FLAGS.target_unclipped_quantile,
                    learning_rate=FLAGS.adaptive_clip_learning_rate))
    else:
        if FLAGS.uniform_weighting:
            aggregation_factory = tff.aggregators.UnweightedMeanFactory()
        else:
            aggregation_factory = tff.aggregators.MeanFactory()

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    # The training loop passes the round number too; evaluation ignores it.
    validation_fn = lambda state, round_num: evaluate_fn(state.model)

    evaluate_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    test_fn = lambda state: evaluate_test_fn(state.model)

    logging.info('Training model:')
    # `Model.summary()` prints its table and returns None, so logging its
    # return value would only log "None". Route each summary line through the
    # logger instead.
    model_builder().summary(print_fn=logging.info)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=validation_fn,
        test_fn=test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
def main(argv):
    """Trains a Stack Overflow next-word model with FedAvg, optionally with DP.

    All configuration is read from `FLAGS`. Clients are weighted uniformly
    when `FLAGS.uniform_weighting` is set, otherwise by the `num_tokens`
    entry of their local outputs. When `FLAGS.noise_multiplier` is set, a DP
    aggregation process is built from the DP flags. Validation runs on the
    validation split; the test function runs on the validation and test
    splits concatenated.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
      ValueError: If differential privacy is requested without
        `FLAGS.uniform_weighting`.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_with_oov', masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov',
                masked_tokens=[pad_token] + oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    datasets = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
    train_dataset, validation_dataset, test_dataset = datasets

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            # Every client contributes equally, whatever it reported.
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            # Weight each client by the number of tokens it reported.
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=(
                FLAGS.clipped_count_budget_allocation),
            expected_clients_per_round=FLAGS.clients_per_round)

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    # `Model.summary()` prints its table and returns None, so logging its
    # return value would only log "None". Route each summary line through the
    # logger instead.
    model_builder().summary(print_fn=logging.info)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate_fn,
        test_fn=test_fn,
        hparam_dict=hparam_dict,
        **training_loop_dict)
def main(argv):
    """Configures the task named by `FLAGS.task` and runs federated training.

    Args:
        argv: Positional command-line arguments. At most one element is
            accepted; anything beyond that is treated as a usage error.

    Raises:
        app.UsageError: If `argv` has more than one element.
        ValueError: If `FLAGS.task` is not one of the tasks handled here.
    """
    if len(argv) > 1:
        raise app.UsageError(
            'Expected no command-line arguments, got: {}'.format(argv))

    # NOTE(review): these four values are not referenced again inside this
    # function, and `iterative_process_builder` (used just below) is not
    # defined here. Confirm where the builder lives and whether these are
    # still needed; the calls are kept because they read flags.
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    # One zero-argument configurator per supported task. Wrapping each call
    # in a lambda keeps task-specific flags from being read unless that task
    # is the one selected.
    configurators = {
        'cifar100': lambda: federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images),
        'emnist_cr': lambda: federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model),
        'emnist_ae': lambda: federated_emnist_ae.configure_training(
            task_spec),
        'shakespeare': lambda: federated_shakespeare.configure_training(
            task_spec,
            sequence_length=FLAGS.shakespeare_sequence_length),
        'stackoverflow_nwp': lambda: (
            federated_stackoverflow.configure_training(
                task_spec,
                vocab_size=FLAGS.so_nwp_vocab_size,
                num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
                sequence_length=FLAGS.so_nwp_sequence_length,
                max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
                num_validation_examples=(
                    FLAGS.so_nwp_num_validation_examples))),
        'stackoverflow_lr': lambda: (
            federated_stackoverflow_lr.configure_training(
                task_spec,
                vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
                vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
                max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
                num_validation_examples=(
                    FLAGS.so_lr_num_validation_examples))),
    }

    configure = configurators.get(FLAGS.task)
    if configure is None:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))
    runner_spec = configure()

    _write_hparam_flags()

    run_kwargs = dict(
        iterative_process=runner_spec.iterative_process,
        client_datasets_fn=runner_spec.client_datasets_fn,
        validation_fn=runner_spec.validation_fn,
        test_fn=runner_spec.test_fn,
        total_rounds=FLAGS.total_rounds,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        rounds_per_eval=FLAGS.rounds_per_eval,
        rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
    training_loop.run(**run_kwargs)