def test_fed_sgd_without_decay_decreases_loss(self):
  """Loss must drop over 5 rounds with a 0.0 client LR and 0.1 server LR.

  Both callbacks use `decay_factor=1.0`. The final assertions check that the
  learning rates are still exactly 0.0 (client) and 0.1 (server) after
  training.
  """
  plateau_config = dict(
      min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.0, **plateau_config)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1, **plateau_config)

  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  final_state, outputs = self._run_rounds(iterative_process, 5)
  first_round, last_round = outputs[0], outputs[-1]
  # Both the pre-training and the during-training loss must improve.
  for phase in ('before_training', 'during_training'):
    self.assertLess(last_round[phase]['loss'], first_round[phase]['loss'])

  self.assertNear(final_state.client_lr_callback.learning_rate, 0.0, 1e-8)
  self.assertNear(final_state.server_lr_callback.learning_rate, 0.1, 1e-8)
def test_comparable_to_fed_avg(self):
  """Per-round training loss must match TFF's reference FedAvg process.

  The adaptive process uses constant (decay_factor=1.0) learning-rate
  callbacks. Over 5 rounds, its `during_training` loss is compared with the
  reference process's `loss` using a tolerance of 1e-4.
  """
  num_rounds = 5
  plateau_config = dict(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(**plateau_config)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(**plateau_config)

  adaptive_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  reference_process = tff.learning.build_federated_averaging_process(
      _uncompiled_model_builder,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

  _, adaptive_outputs = self._run_rounds(adaptive_process, num_rounds)
  _, reference_outputs = self._run_rounds_tff_fedavg(
      reference_process, num_rounds)

  for round_num in range(num_rounds):
    self.assertAllClose(
        adaptive_outputs[round_num]['during_training']['loss'],
        reference_outputs[round_num]['loss'],
        1e-4)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> tff.templates.IterativeProcess:
  """Creates an iterative process using a given TFF `model_fn`.

  `client_lr_callback`, `server_lr_callback`, `client_optimizer_fn` and
  `server_optimizer_fn` are not parameters of this function. They are free
  names resolved from an enclosing or module scope that is not shown here.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor providing the weight
      in the federated average of model deltas. If not provided, the default
      is the total number of examples processed on device.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  return adaptive_fed_avg.build_fed_avg_process(
      model_fn,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
def test_iterative_process_type_signature(self):
  """Checks the `initialize` and `next` type signatures of the process."""

  def make_constant_lr_callback():
    # A fresh callback with learning rate 0.1 and decay_factor=1.0.
    return callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)

  client_lr_callback = make_constant_lr_callback()
  server_lr_callback = make_constant_lr_callback()
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  # Expected server state: model weights, optimizer state and both callbacks.
  lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
  model_weights_type = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type),
      tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  # Expected `next`: (state, client data) -> (state, before/during metrics).
  batch_type = collections.OrderedDict(
      x=tff.TensorType(tf.float32, [None, 1]),
      y=tff.TensorType(tf.float32, [None, 1]))
  dataset_type = tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type, federated_dataset=dataset_type),
      result=(server_state_type, output_type))

  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_small_lr_comparable_zero_lr(self):
  """A 1e-8 client learning rate must behave like a 0.0 one.

  Both processes share the same server callback and run 5 rounds. The final
  trainable weights and all per-round training outputs must agree to within
  a tolerance of 1e-4.
  """
  client_lr_callback1 = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.0,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  client_lr_callback2 = callbacks.create_reduce_lr_on_plateau(
      learning_rate=1e-8,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)

  # Use the same call shape as every other call site:
  # (model_fn, client_lr_callback, server_lr_callback, ...).
  # Passing separate `update_reduce_lr_on_plateau` positionals shifts the
  # arguments into the optimizer slots and collides with the keyword
  # `client_optimizer_fn`, raising a TypeError.
  iterative_process1 = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback1,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  iterative_process2 = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback2,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  state1, train_outputs1 = self._run_rounds(iterative_process1, 5)
  state2, train_outputs2 = self._run_rounds(iterative_process2, 5)

  self.assertAllClose(state1.model.trainable, state2.model.trainable, 1e-4)
  self.assertAllClose(train_outputs1, train_outputs2, 1e-4)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
  """Builds an adaptive-LR FedAvg iterative process from `model_fn`.

  `client_lr_callback`, `server_lr_callback`, `client_optimizer_fn` and
  `server_optimizer_fn` are free names resolved outside this function. They
  are not parameters.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor providing the weight
      in the federated average of model deltas. If not provided, the default
      is the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  lr_callbacks = (client_lr_callback, server_lr_callback)
  process = adaptive_fed_avg.build_fed_avg_process(
      model_fn,
      *lr_callbacks,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn)
  return process
def test_build_with_preprocess_funtion(self):
  """Checks type signatures when `dataset_preprocess_comp` is supplied.

  The raw client data is `tf.data.Dataset.range(5)`. The preprocessing
  computation maps it into `x`/`y` batches, so `next` is expected to accept
  the raw dataset's sequence type as its federated dataset parameter.
  """
  test_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_example(value):
      return collections.OrderedDict(
          x=[float(value) * 1.0], y=[float(value) * 3.0 + 1.0])

    return ds.map(to_example).repeat().batch(2).take(3)

  def make_constant_lr_callback():
    # A fresh callback with learning rate 0.1 and decay_factor=1.0.
    return callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)

  client_lr_callback = make_constant_lr_callback()
  server_lr_callback = make_constant_lr_callback()
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
  model_weights_type = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type),
      tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type,
          federated_dataset=client_datasets_type),
      result=(server_state_type, output_type))

  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def from_flags(input_spec,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
      object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`, defaults
      to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess` instance.
  """
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.server_learning_rate,
      decay_factor=FLAGS.server_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  # Use the same call shape as every other call site:
  # (model_fn, client_lr_callback, server_lr_callback, ...).
  # The previous extra `update_reduce_lr_on_plateau` positionals would shift
  # into the optimizer slots and collide with the `client_optimizer_fn`
  # keyword (TypeError).
  return adaptive_fed_avg.build_fed_avg_process(
      tff_model_fn,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn)