def test_fed_sgd_without_decay_decreases_loss(self):
  """Checks that training reduces loss while both LRs stay fixed.

  Both callbacks use `decay_factor=1.0`, i.e. no decay, so after five
  rounds the learning rates must equal their initial values.
  """
  client_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.0, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  server_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_cb,
      server_cb,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  state, train_outputs = self._run_rounds(process, 5)

  # Loss should decrease over the five rounds, both before and during
  # local training.
  self.assertLess(train_outputs[-1]['before_training']['loss'],
                  train_outputs[0]['before_training']['loss'])
  self.assertLess(train_outputs[-1]['during_training']['loss'],
                  train_outputs[0]['during_training']['loss'])
  # Learning rates are unchanged from their initial values.
  self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
  self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
def test_comparable_to_fed_avg(self):
  """With constant LRs, adaptive FedAvg should match TFF's FedAvg losses."""
  client_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  server_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  adaptive_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_cb,
      server_cb,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  # Reference process with the same fixed client/server learning rates.
  reference_process = tff.learning.build_federated_averaging_process(
      _uncompiled_model_builder,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

  _, train_outputs = self._run_rounds(adaptive_process, 5)
  _, reference_outputs = self._run_rounds_tff_fedavg(reference_process, 5)

  # Per-round training losses should agree to within 1e-4.
  for round_num in range(5):
    self.assertAllClose(train_outputs[round_num]['during_training']['loss'],
                        reference_outputs[round_num]['loss'], 1e-4)
def test_iterative_process_type_signature(self):
  """Verifies the TFF type signatures of `initialize` and `next`."""
  client_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  server_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_cb,
      server_cb,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  # The expected server state: model weights, optimizer state, and one
  # callback structure each for the client and server learning rates.
  callback_type = tff.framework.type_from_tensors(client_cb)
  model_weights = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=model_weights,
          optimizer_state=[tf.int64],
          client_lr_callback=callback_type,
          server_lr_callback=callback_type), tff.SERVER)

  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)

  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type, federated_dataset=dataset_type),
      result=(server_state_type, output_type))
  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type, expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_raises_bad_decay_factor(self):
  """Out-of-range decay factors must raise `ValueError`."""
  # A factor above 1 is invalid.
  with self.assertRaises(ValueError):
    callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, decay_factor=2.0, cooldown=0)
  # A negative factor is invalid.
  with self.assertRaises(ValueError):
    callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, decay_factor=-1.0)
def test_small_lr_comparable_zero_lr(self):
  """A tiny client LR (1e-8) should train near-identically to a zero LR."""
  zero_lr_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.0, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  tiny_lr_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=1e-8, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)
  server_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0,
      cooldown=0)

  def build_process(client_cb):
    # Both processes differ only in the client-side LR callback.
    return adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_cb,
        server_cb,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

  state1, train_outputs1 = self._run_rounds(build_process(zero_lr_cb), 5)
  state2, train_outputs2 = self._run_rounds(build_process(tiny_lr_cb), 5)

  self.assertAllClose(state1.model.trainable, state2.model.trainable, 1e-4)
  self.assertAllClose(train_outputs1, train_outputs2, 1e-4)
def test_get_model_weights(self):
  """`get_model_weights` must mirror the weights held in the server state."""
  client_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, window_size=1, patience=1, decay_factor=1.0,
      cooldown=0)
  server_cb = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, window_size=1, patience=1, decay_factor=1.0,
      cooldown=0)
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_cb,
      server_cb,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  def check_weights(state):
    # The extracted value must be a ModelWeights whose trainable part
    # matches the server state's model.
    extracted = iterative_process.get_model_weights(state)
    self.assertIsInstance(extracted, tff.learning.ModelWeights)
    self.assertAllClose(state.model.trainable, extracted.trainable)

  # Check both the freshly initialized state and a trained state.
  check_weights(iterative_process.initialize())
  state, _ = self._run_rounds(iterative_process, 5)
  check_weights(state)
def test_min_lr(self):
  """The learning rate is clipped at `min_lr` and never falls below it."""
  lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      decay_factor=0.5,
      min_lr=0.2,
      minimize=False,
      window_size=1,
      patience=1,
      cooldown=0)
  logging.info('LR Callback: %s', lr_callback)
  # The initial rate 0.1 is below min_lr=0.2, so it is clipped up to 0.2.
  self.assertEqual(lr_callback.learning_rate, 0.2)

  for round_num in range(5):
    # Feed strictly worsening metrics; the LR must stay pinned at min_lr
    # while the wait counter keeps growing.
    lr_callback = lr_callback.update(-float(round_num))
    logging.info('LR Callback: %s', lr_callback)
    self.assertEqual(lr_callback.best, 0.0)
    self.assertEqual(lr_callback.learning_rate, 0.2)
    self.assertEqual(lr_callback.wait, round_num + 1)
def test_lr_decay_after_patience_rounds(self):
  """The LR decays only after `patience` rounds without improvement."""
  lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=1.0,
      decay_factor=0.5,
      minimize=False,
      window_size=3,
      patience=5,
      cooldown=0)
  logging.info('LR Callback: %s', lr_callback)
  # With minimize=False the window starts filled with zeros.
  self.assertEqual(lr_callback.metrics_window, [0.0, 0.0, 0.0])

  # Four non-improving updates: the LR holds while `wait` counts up.
  for round_num in range(4):
    lr_callback = lr_callback.update(-1.0)
    logging.info('LR Callback: %s', lr_callback)
    self.assertEqual(lr_callback.best, 0.0)
    self.assertEqual(lr_callback.learning_rate, 1.0)
    self.assertEqual(lr_callback.wait, round_num + 1)

  # The fifth non-improving update triggers a single decay (1.0 -> 0.5)
  # and resets the wait counter.
  lr_callback = lr_callback.update(-1.0)
  logging.info('LR Callback: %s', lr_callback)
  self.assertEqual(lr_callback.best, 0.0)
  self.assertEqual(lr_callback.learning_rate, 0.5)
  self.assertEqual(lr_callback.wait, 0)
def test_window_with_inf_values(self):
  """Tests behavior while the metrics window still holds initial infs.

  With `minimize=True` the window starts filled with `inf`, so `best`
  remains `inf` until a full window of real metrics has been observed;
  with `patience=1` each such update decays the learning rate.
  """
  lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=1.0,
      decay_factor=0.5,
      minimize=True,
      window_size=3,
      patience=1,
      cooldown=0)
  logging.info('LR Callback: %s', lr_callback)
  # Use np.inf rather than np.Inf: the capitalized alias was removed in
  # NumPy 2.0 and raises AttributeError there.
  self.assertEqual(lr_callback.metrics_window, [np.inf for _ in range(3)])

  for i in range(2):
    lr_callback = lr_callback.update(3.0)
    logging.info('LR Callback: %s', lr_callback)
    # Window still contains an inf, so `best` stays inf and the LR halves
    # on each update.
    self.assertEqual(lr_callback.best, np.inf)
    self.assertEqual(lr_callback.learning_rate, (0.5)**(i + 1))
    self.assertEqual(lr_callback.wait, 0)

  lr_callback = lr_callback.update(6.0)
  logging.info('LR Callback: %s', lr_callback)
  # Window is now [3.0, 3.0, 6.0]; its average 4.0 becomes the new best.
  self.assertEqual(lr_callback.best, 4.0)
  self.assertEqual(lr_callback.learning_rate, 0.25)
  self.assertEqual(lr_callback.wait, 0)
def test_cooldown(self):
  """No decay occurs while the cooldown counter is still running down."""
  lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=2.0,
      decay_factor=0.5,
      minimize=False,
      window_size=1,
      patience=0,
      cooldown=3)
  logging.info('LR Callback: %s', lr_callback)
  self.assertEqual(lr_callback.learning_rate, 2.0)
  self.assertEqual(lr_callback.cooldown, 3)
  self.assertEqual(lr_callback.cooldown_counter, 3)

  # Two non-improving updates only tick the cooldown counter down
  # (3 -> 2 -> 1); the learning rate is untouched.
  for step in range(2):
    lr_callback = lr_callback.update(-1.0)
    logging.info('LR Callback: %s', lr_callback)
    self.assertEqual(lr_callback.learning_rate, 2.0)
    self.assertEqual(lr_callback.wait, 0)
    self.assertEqual(lr_callback.cooldown, 3)
    self.assertEqual(lr_callback.cooldown_counter, 2 - step)

  # Once the counter has run down, the next update decays the LR
  # (2.0 -> 1.0) and restarts the cooldown at 3.
  lr_callback = lr_callback.update(-1.0)
  logging.info('LR Callback: %s', lr_callback)
  self.assertEqual(lr_callback.learning_rate, 1.0)
  self.assertEqual(lr_callback.wait, 0)
  self.assertEqual(lr_callback.cooldown, 3)
  self.assertEqual(lr_callback.cooldown_counter, 3)
def main(argv):
  # Entry point: builds LR-on-plateau callbacks and optimizer factories from
  # command-line flags, constructs the adaptive FedAvg process for the task
  # selected by --task, and runs the federated training loop.
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')

  # Client and server callbacks share the plateau-detection flags
  # (min_delta, min_lr, window_size, patience); only the initial learning
  # rate and decay factor differ per side.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Closes over the flag-configured LR callbacks and optimizer factories.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  # Dispatch on --task to the matching dataset/model-specific configurator.
  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(
        task_spec, crop_size=FLAGS.cifar100_crop_size)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(
        task_spec, model=FLAGS.emnist_cr_model)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(
        task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(
        task_spec,
        vocab_size=FLAGS.so_nwp_vocab_size,
        num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
        sequence_length=FLAGS.so_nwp_sequence_length,
        max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
        num_validation_examples=FLAGS.so_nwp_num_validation_examples)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(
        task_spec,
        vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
        vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
        max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
        num_validation_examples=FLAGS.so_lr_num_validation_examples)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_profile=FLAGS.rounds_per_profile)
def main(argv):
  """Builds flag-configured LR callbacks and runs the selected federated task.

  Constructs client/server ReduceLROnPlateau callbacks and optimizer
  factories from flags, wires them into an adaptive FedAvg iterative-process
  builder, and dispatches on --task to the matching `run_federated` runner.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If any extra command-line arguments are passed.
    ValueError: If --task is not one of the supported task names.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  # Both callbacks share the plateau-detection flags; only the initial
  # learning rate and decay factor differ between client and server.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
      client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the
        weight in the federated average of model deltas. If not provided,
        the default is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    return adaptive_fed_avg.build_fed_avg_process(
        model_fn,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  shared_args = utils_impl.lookup_flag_values(shared_flags)
  shared_args['iterative_process_builder'] = iterative_process_builder

  # Dispatch on --task to the matching task runner.
  if FLAGS.task == 'cifar100':
    hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
    federated_cifar100.run_federated(
        **shared_args,
        crop_size=FLAGS.cifar100_crop_size,
        hparam_dict=hparam_dict)
  elif FLAGS.task == 'emnist_cr':
    federated_emnist.run_federated(
        **shared_args, model=FLAGS.emnist_cr_model, hparam_dict=hparam_dict)
  elif FLAGS.task == 'emnist_ae':
    federated_emnist_ae.run_federated(**shared_args, hparam_dict=hparam_dict)
  elif FLAGS.task == 'shakespeare':
    federated_shakespeare.run_federated(
        **shared_args,
        sequence_length=FLAGS.shakespeare_sequence_length,
        hparam_dict=hparam_dict)
  elif FLAGS.task == 'stackoverflow_nwp':
    # Collect the task-specific flags, stripping the 'so_nwp_' prefix.
    so_nwp_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_nwp_'):
        so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
    federated_stackoverflow.run_federated(
        **shared_args, **so_nwp_flags, hparam_dict=hparam_dict)
  elif FLAGS.task == 'stackoverflow_lr':
    # Collect the task-specific flags, stripping the 'so_lr_' prefix.
    so_lr_flags = collections.OrderedDict()
    for flag_name in task_flags:
      if flag_name.startswith('so_lr_'):
        so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
    federated_stackoverflow_lr.run_federated(
        **shared_args, **so_lr_flags, hparam_dict=hparam_dict)
  else:
    # Previously an unrecognized task fell through silently and the program
    # exited without training; fail loudly instead.
    raise ValueError('--task flag {} is not supported.'.format(FLAGS.task))