def test_fed_sgd_without_decay_decreases_loss(self):
    """Training loss should improve over rounds when LR decay is disabled."""
    # decay_factor=1.0 means the callbacks never actually change the LR.
    client_cb = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.0,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)
    server_cb = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)
    process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_cb,
        server_cb,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    state, train_outputs = self._run_rounds(process, 5)
    # Loss measured both before and during local training should decrease
    # between the first and last rounds.
    self.assertLess(train_outputs[-1]['before_training']['loss'],
                    train_outputs[0]['before_training']['loss'])
    self.assertLess(train_outputs[-1]['during_training']['loss'],
                    train_outputs[0]['during_training']['loss'])
    # With decay disabled the learning rates must retain their initial values.
    self.assertNear(state.client_lr_callback.learning_rate, 0.0, 1e-8)
    self.assertNear(state.server_lr_callback.learning_rate, 0.1, 1e-8)
def test_comparable_to_fed_avg(self):
    """With decay disabled, the adaptive process should match vanilla FedAvg."""

    def _constant_lr_callback():
        # decay_factor=1.0 pins the learning rate at 0.1 for all rounds.
        return callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

    adaptive_process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        _constant_lr_callback(),
        _constant_lr_callback(),
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    # Reference: standard TFF FedAvg with the same fixed learning rates.
    reference_process = tff.learning.build_federated_averaging_process(
        _uncompiled_model_builder,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))
    _, adaptive_outputs = self._run_rounds(adaptive_process, 5)
    _, reference_outputs = self._run_rounds_tff_fedavg(reference_process, 5)
    # Per-round training losses of the two processes should agree closely.
    for round_num in range(5):
        self.assertAllClose(
            adaptive_outputs[round_num]['during_training']['loss'],
            reference_outputs[round_num]['loss'], 1e-4)
def test_iterative_process_type_signature(self):
    """Asserts `initialize` and `next` expose the expected TFF type signatures."""
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, min_delta=0.5, window_size=2,
        decay_factor=1.0, cooldown=0)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, min_delta=0.5, window_size=2,
        decay_factor=1.0, cooldown=0)
    iterative_process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder, client_lr_callback, server_lr_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    # Derive the TFF type of an LR callback from a concrete instance.
    lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
    # Expected server state: weights of the single-unit linear model, the
    # optimizer state, and both LR callbacks, all placed at the server.
    server_state_type = tff.FederatedType(
        adaptive_fed_avg.ServerState(
            model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
            optimizer_state=[tf.int64],
            client_lr_callback=lr_callback_type,
            server_lr_callback=lr_callback_type), tff.SERVER)
    self.assertEqual(
        iterative_process.initialize.type_signature,
        tff.FunctionType(parameter=None, result=server_state_type))
    dataset_type = tff.FederatedType(
        tff.SequenceType(
            collections.OrderedDict(
                x=tff.TensorType(tf.float32, [None, 1]),
                y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
    metrics_type = tff.FederatedType(
        collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
    # `next` reports metrics both before and during local training.
    output_type = collections.OrderedDict(
        before_training=metrics_type, during_training=metrics_type)
    expected_result_type = (server_state_type, output_type)
    expected_type = tff.FunctionType(
        parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=dataset_type),
        result=expected_result_type)
    actual_type = iterative_process.next.type_signature
    self.assertEqual(
        actual_type, expected_type,
        msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_raises_bad_decay_factor(self):
    """Decay factors outside the unit interval must raise `ValueError`."""
    # A factor above 1 would grow the learning rate on plateau.
    with self.assertRaises(ValueError):
        callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, decay_factor=2.0, cooldown=0)
    # A negative factor is nonsensical for a multiplicative decay.
    with self.assertRaises(ValueError):
        callbacks.create_reduce_lr_on_plateau(
            learning_rate=0.1, decay_factor=-1.0)
def test_cooldown(self):
    """The LR must not decay while the cooldown counter is still positive."""
    cb = callbacks.create_reduce_lr_on_plateau(
        learning_rate=2.0,
        decay_factor=0.5,
        minimize=False,
        window_size=1,
        patience=0,
        cooldown=3)
    logging.info('LR Callback: %s', cb)
    self.assertEqual(cb.learning_rate, 2.0)
    self.assertEqual(cb.cooldown, 3)
    self.assertEqual(cb.cooldown_counter, 3)
    # Two non-improving updates only tick the cooldown counter down; the
    # learning rate stays untouched during the cooldown window.
    for step in range(2):
        cb = callbacks.update_reduce_lr_on_plateau(cb, -1.0)
        logging.info('LR Callback: %s', cb)
        self.assertEqual(cb.learning_rate, 2.0)
        self.assertEqual(cb.wait, 0)
        self.assertEqual(cb.cooldown, 3)
        self.assertEqual(cb.cooldown_counter, 2 - step)
    # Cooldown has expired: this update halves the LR and restarts cooldown.
    cb = callbacks.update_reduce_lr_on_plateau(cb, -1)
    logging.info('LR Callback: %s', cb)
    self.assertEqual(cb.learning_rate, 1.0)
    self.assertEqual(cb.wait, 0)
    self.assertEqual(cb.cooldown, 3)
    self.assertEqual(cb.cooldown_counter, 3)
def test_small_lr_comparable_zero_lr(self):
    """A tiny (1e-8) client LR should train nearly identically to a zero LR."""

    def _client_callback(lr):
        # decay_factor=1.0 keeps the client LR fixed for the whole run.
        return callbacks.create_reduce_lr_on_plateau(
            learning_rate=lr,
            min_delta=0.5,
            window_size=2,
            decay_factor=1.0,
            cooldown=0)

    zero_lr_callback = _client_callback(0.0)
    tiny_lr_callback = _client_callback(1e-8)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)

    def _build_process(client_callback):
        return adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_callback,
            callbacks.update_reduce_lr_on_plateau,
            server_lr_callback,
            callbacks.update_reduce_lr_on_plateau,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

    process_zero = _build_process(zero_lr_callback)
    process_tiny = _build_process(tiny_lr_callback)
    state_zero, outputs_zero = self._run_rounds(process_zero, 5)
    state_tiny, outputs_tiny = self._run_rounds(process_tiny, 5)
    # Final model weights and per-round metrics should be nearly identical.
    self.assertAllClose(state_zero.model.trainable,
                        state_tiny.model.trainable, 1e-4)
    self.assertAllClose(outputs_zero, outputs_tiny, 1e-4)
def test_iterative_process_type_signature(self):
    """Validates the type signatures of the flag-configured iterative process."""
    iterative_process = decay_iterative_process_builder.from_flags(
        input_spec=get_input_spec(),
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    # Build a callback from the same flags so its TFF type matches the one
    # `from_flags` embeds in the server state.
    dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)
    # Expected server state: single-unit linear model weights, optimizer
    # state, and both LR callbacks, placed at the server.
    server_state_type = tff.FederatedType(
        adaptive_fed_avg.ServerState(
            model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
            optimizer_state=[tf.int64],
            client_lr_callback=lr_callback_type,
            server_lr_callback=lr_callback_type), tff.SERVER)
    self.assertEqual(
        iterative_process.initialize.type_signature,
        tff.FunctionType(parameter=None, result=server_state_type))
    dataset_type = tff.FederatedType(
        tff.SequenceType(
            collections.OrderedDict(
                x=tff.TensorType(tf.float32, [None, 1]),
                y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
    # Metrics include MSE (from metrics_builder) plus the aggregate loss.
    metrics_type = tff.FederatedType(
        collections.OrderedDict(
            mean_squared_error=tff.TensorType(tf.float32),
            loss=tff.TensorType(tf.float32)), tff.SERVER)
    output_type = collections.OrderedDict(
        before_training=metrics_type, during_training=metrics_type)
    expected_result_type = (server_state_type, output_type)
    expected_type = tff.FunctionType(
        parameter=(server_state_type, dataset_type),
        result=expected_result_type)
    actual_type = iterative_process.next.type_signature
    # Equivalence (not strict equality) tolerates naming differences in types.
    self.assertTrue(actual_type.is_equivalent_to(expected_type))
def test_min_lr(self):
    """The learning rate is clamped at `min_lr`, at creation and after decay."""
    cb = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        decay_factor=0.5,
        min_lr=0.2,
        minimize=False,
        window_size=1,
        patience=1,
        cooldown=0)
    logging.info('LR Callback: %s', cb)
    # The requested LR (0.1) is below the floor, so it is raised to min_lr.
    self.assertEqual(cb.learning_rate, 0.2)
    # Feed a stream of worsening metrics; decay can never push the LR
    # below the floor.
    for round_num in range(5):
        metric = -float(round_num)
        cb = cb.update(metric)
        logging.info('LR Callback: %s', cb)
        self.assertEqual(cb.best, 0.0)
        self.assertEqual(cb.learning_rate, 0.2)
        self.assertEqual(cb.wait, round_num + 1)
def test_lr_decay_after_patience_rounds(self):
    """The LR is halved only once the wait count exhausts the patience."""
    cb = callbacks.create_reduce_lr_on_plateau(
        learning_rate=1.0,
        decay_factor=0.5,
        minimize=False,
        window_size=3,
        patience=5,
        cooldown=0)
    logging.info('LR Callback: %s', cb)
    # With minimize=False the averaging window is seeded with zeros.
    self.assertEqual(cb.metrics_window, [0.0, 0.0, 0.0])
    # Four non-improving rounds: wait grows, learning rate is unchanged.
    for round_num in range(4):
        cb = cb.update(-1.0)
        logging.info('LR Callback: %s', cb)
        self.assertEqual(cb.best, 0.0)
        self.assertEqual(cb.learning_rate, 1.0)
        self.assertEqual(cb.wait, round_num + 1)
    # The next non-improving round triggers the decay and resets the wait.
    cb = cb.update(-1.0)
    logging.info('LR Callback: %s', cb)
    self.assertEqual(cb.best, 0.0)
    self.assertEqual(cb.learning_rate, 0.5)
    self.assertEqual(cb.wait, 0)
def test_window_with_inf_values(self):
    """Checks decay behavior while the averaging window still contains inf.

    Fix: use `np.inf` rather than the `np.Inf` alias — the capitalized
    aliases were removed in NumPy 2.0, so `np.Inf` raises AttributeError
    on current NumPy. Behavior is otherwise identical.
    """
    lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=1.0,
        decay_factor=0.5,
        minimize=True,
        window_size=3,
        patience=1,
        cooldown=0)
    logging.info('LR Callback: %s', lr_callback)
    # With minimize=True the window is seeded with +inf sentinels.
    self.assertEqual(lr_callback.metrics_window, [np.inf for _ in range(3)])
    for i in range(2):
        lr_callback = lr_callback.update(3.0)
        logging.info('LR Callback: %s', lr_callback)
        # While any inf remains in the window the average is inf, so `best`
        # stays inf and the LR decays each round (patience=1).
        self.assertEqual(lr_callback.best, np.inf)
        self.assertEqual(lr_callback.learning_rate, (0.5)**(i + 1))
        self.assertEqual(lr_callback.wait, 0)
    lr_callback = lr_callback.update(6.0)
    logging.info('LR Callback: %s', lr_callback)
    # The window now holds [3.0, 3.0, 6.0]; its mean 4.0 becomes the best.
    self.assertEqual(lr_callback.best, 4.0)
    self.assertEqual(lr_callback.learning_rate, 0.25)
    self.assertEqual(lr_callback.wait, 0)
def main(argv):
    """Runs a federated training task configured entirely by absl flags.

    Builds client/server reduce-on-plateau LR callbacks and optimizer
    factories from flags, then dispatches to the task-specific federated
    training runner selected by `--task`.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If any positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    # Both callbacks share the plateau-detection flags; only the initial
    # learning rate and decay factor differ between client and server.
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    shared_args = utils_impl.lookup_flag_values(shared_flags)
    shared_args['iterative_process_builder'] = iterative_process_builder
    # Dispatch to the task-specific runner, forwarding any task-only flags.
    if FLAGS.task == 'cifar100':
        hparam_dict['cifar100_crop_size'] = FLAGS.cifar100_crop_size
        federated_cifar100.run_federated(**shared_args,
                                         crop_size=FLAGS.cifar100_crop_size,
                                         hparam_dict=hparam_dict)
    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**shared_args,
                                       emnist_model=FLAGS.emnist_cr_model,
                                       hparam_dict=hparam_dict)
    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**shared_args,
                                          hparam_dict=hparam_dict)
    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **shared_args,
            sequence_length=FLAGS.shakespeare_sequence_length,
            hparam_dict=hparam_dict)
    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in task_flags:
            if flag_name.startswith('so_nwp_'):
                # Strip the 'so_nwp_' prefix to recover the runner's kwarg name.
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        federated_stackoverflow.run_federated(**shared_args, **so_nwp_flags,
                                              hparam_dict=hparam_dict)
    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in task_flags:
            if flag_name.startswith('so_lr_'):
                # Strip the 'so_lr_' prefix to recover the runner's kwarg name.
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        federated_stackoverflow_lr.run_federated(**shared_args, **so_lr_flags,
                                                 hparam_dict=hparam_dict)
def test_build_with_preprocess_funtion(self):
    """Tests type signatures when a dataset preprocessing comp is supplied."""
    # NOTE(review): 'funtion' in the method name looks like a typo for
    # 'function'; left unchanged here since test names are the public
    # interface for test selection/discovery.
    test_dataset = tf.data.Dataset.range(5)
    # With preprocessing, `next` must accept the *raw* client dataset type
    # (scalar int64 sequence), not the batched (x, y) type.
    client_datasets_type = tff.FederatedType(
        tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
    def preprocess_dataset(ds):
        """Maps scalar range elements to batched (x, y) pairs for the model."""

        def to_batch(x):
            # Synthetic linear relation y = 3x + 1.
            return collections.OrderedDict(x=[float(x) * 1.0],
                                           y=[float(x) * 3.0 + 1.0])

        return ds.map(to_batch).repeat().batch(2).take(3)

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, min_delta=0.5, window_size=2,
        decay_factor=1.0, cooldown=0)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, min_delta=0.5, window_size=2,
        decay_factor=1.0, cooldown=0)
    iterative_process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_lr_callback,
        server_lr_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        dataset_preprocess_comp=preprocess_dataset)
    # Derive the TFF type of an LR callback from a concrete instance.
    lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
    server_state_type = tff.FederatedType(
        adaptive_fed_avg.ServerState(
            model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
            optimizer_state=[tf.int64],
            client_lr_callback=lr_callback_type,
            server_lr_callback=lr_callback_type), tff.SERVER)
    self.assertEqual(
        iterative_process.initialize.type_signature,
        tff.FunctionType(parameter=None, result=server_state_type))
    metrics_type = tff.FederatedType(
        collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
    output_type = collections.OrderedDict(
        before_training=metrics_type, during_training=metrics_type)
    expected_result_type = (server_state_type, output_type)
    expected_type = tff.FunctionType(
        parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type),
        result=expected_result_type)
    actual_type = iterative_process.next.type_signature
    self.assertEqual(actual_type, expected_type,
                     msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def from_flags(input_spec,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
    """Builds a `tff.templates.IterativeProcess` instance from flags.

    The iterative process is designed to incorporate learning rate schedules,
    which are configured via flags.

    Args:
      input_spec: A value convertible to a `tff.Type`, representing the data
        which will be fed into the `tff.templates.IterativeProcess.next`
        function over the course of training. Generally, this can be found by
        accessing the `element_spec` attribute of a client `tf.data.Dataset`.
      model_builder: A no-arg function that returns an uncompiled
        `tf.keras.Model` object.
      loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
        object.
      metrics_builder: A no-arg function that returns a list of
        `tf.keras.metrics.Metric` objects.
      client_weight_fn: An optional callable that takes the result of
        `tff.learning.Model.report_local_outputs` from the model returned by
        `model_builder`, and returns a scalar client weight. If `None`,
        defaults to the number of examples processed over all batches.

    Returns:
      A `tff.templates.IterativeProcess` instance.
    """
    # Both callbacks share the plateau-detection flags; only the initial
    # learning rate and decay factor differ between client and server.
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    def tff_model_fn():
        # Wraps the Keras model as a `tff.learning.Model`, building fresh
        # loss and metric instances on each invocation.
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    return adaptive_fed_avg.build_fed_avg_process(
        tff_model_fn,
        client_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        server_lr_callback,
        callbacks.update_reduce_lr_on_plateau,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)
def main(argv):
    """Runs a federated training task with optional dataset preprocessing.

    Builds client/server reduce-on-plateau LR callbacks and optimizer
    factories from flags, assembles the shared runner arguments, and
    dispatches to the task selected by `--task`.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If any positional command-line arguments are given.
      ValueError: If `--task` names an unsupported task.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    # Both callbacks share the plateau-detection flags; only the initial
    # learning rate and decay factor differ between client and server.
    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)
    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_preprocess_comp: Optional[tff.Computation] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.
          client_weight_fn: Optional function that takes the output of
            `model.report_local_outputs` and returns a tensor providing the
            weight in the federated average of model deltas. If not provided,
            the default is the total number of examples processed on device.
          dataset_preprocess_comp: Optional `tff.Computation` that sets up a
            data pipeline on the clients. The computation must take a sequence
            of values and return a sequence of values, or in TFF type
            shorthand `(U* -> V*)`. If `None`, no dataset preprocessing is
            applied.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weight_fn=client_weight_fn,
            dataset_preprocess_comp=dataset_preprocess_comp)

    assign_weights_fn = adaptive_fed_avg.ServerState.assign_weights_to_keras_model
    # Arguments common to every task runner, taken from the shared flags.
    common_args = collections.OrderedDict([
        ('iterative_process_builder', iterative_process_builder),
        ('assign_weights_fn', assign_weights_fn),
        ('client_epochs_per_round', FLAGS.client_epochs_per_round),
        ('client_batch_size', FLAGS.client_batch_size),
        ('clients_per_round', FLAGS.clients_per_round),
        ('max_batches_per_client', FLAGS.max_batches_per_client),
        ('client_datasets_random_seed', FLAGS.client_datasets_random_seed)
    ])
    # Dispatch to the task-specific runner, forwarding any task-only flags.
    if FLAGS.task == 'cifar100':
        federated_cifar100.run_federated(**common_args,
                                         crop_size=FLAGS.cifar100_crop_size)
    elif FLAGS.task == 'emnist_cr':
        federated_emnist.run_federated(**common_args,
                                       emnist_model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        federated_emnist_ae.run_federated(**common_args)
    elif FLAGS.task == 'shakespeare':
        federated_shakespeare.run_federated(
            **common_args, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        so_nwp_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_nwp_'):
                # Strip the 'so_nwp_' prefix to recover the runner's kwarg name.
                so_nwp_flags[flag_name[7:]] = FLAGS[flag_name].value
        federated_stackoverflow.run_federated(**common_args, **so_nwp_flags)
    elif FLAGS.task == 'stackoverflow_lr':
        so_lr_flags = collections.OrderedDict()
        for flag_name in FLAGS:
            if flag_name.startswith('so_lr_'):
                # Strip the 'so_lr_' prefix to recover the runner's kwarg name.
                so_lr_flags[flag_name[6:]] = FLAGS[flag_name].value
        federated_stackoverflow_lr.run_federated(**common_args, **so_lr_flags)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))