def test_comparable_to_fed_avg(self):
    """Compares per-round training loss with TFF's federated averaging.

    Builds the adaptive process and a reference process from
    `tff.learning.build_federated_averaging_process`, runs each for the same
    number of rounds, and asserts that every round's 'during_training' loss
    from the adaptive process is close to the reference round's 'loss'.
    """
    num_rounds = 5

    # Client and server callbacks use identical settings; each call still
    # produces its own callback object.
    plateau_kwargs = dict(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)
    client_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)
    server_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)

    adaptive_process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_callback,
        server_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    # NOTE(review): the server callback above starts at 0.1 while the
    # reference server optimizer below uses SGD(1.0) -- confirm this is the
    # intended pairing.
    baseline_process = tff.learning.build_federated_averaging_process(
        _uncompiled_model_builder,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

    _, adaptive_metrics = self._run_rounds(adaptive_process, num_rounds)
    _, baseline_metrics = self._run_rounds_tff_fedavg(
        baseline_process, num_rounds)

    for round_idx in range(num_rounds):
        adaptive_loss = adaptive_metrics[round_idx]['during_training']['loss']
        baseline_loss = baseline_metrics[round_idx]['loss']
        self.assertAllClose(adaptive_loss, baseline_loss, 1e-4)
def test_fed_sgd_without_decay_decreases_loss(self):
    """Checks that loss drops over five rounds with a zero client LR.

    The client callback starts at learning rate 0.0 and the server callback
    at 0.1, both with decay_factor=1.0. After training, the last round's loss
    must be lower than the first round's for both reported phases, and both
    callbacks' learning rates must still equal their starting values.
    """
    # Settings common to both callbacks; only the learning rate differs.
    shared_kwargs = dict(
        min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
    client_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.0, **shared_kwargs)
    server_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, **shared_kwargs)

    process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_callback,
        server_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    final_state, round_metrics = self._run_rounds(process, 5)

    # Same comparison for each reported phase, in the original order.
    for phase in ('before_training', 'during_training'):
        self.assertLess(round_metrics[-1][phase]['loss'],
                        round_metrics[0][phase]['loss'])

    self.assertNear(final_state.client_lr_callback.learning_rate, 0.0, 1e-8)
    self.assertNear(final_state.server_lr_callback.learning_rate, 0.1, 1e-8)
def test_iterative_process_type_signature(self):
    """Asserts the TFF type signatures of `initialize` and `next`.

    `initialize` must take no parameter and return the server-placed state
    type. `next` must take (server_state, federated_dataset) and return the
    state type together with the 'before_training' / 'during_training'
    metrics, each holding a float32 'loss'.
    """
    plateau_kwargs = dict(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)
    client_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)
    server_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)

    process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_callback,
        server_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    # The client callback's type is used for both callback fields of the
    # expected server state.
    callback_type = tff.framework.type_from_tensors(client_callback)
    weights_type = tff.learning.ModelWeights(
        trainable=(tff.TensorType(tf.float32, [1, 1]),
                   tff.TensorType(tf.float32, [1])),
        non_trainable=())
    state_type = tff.FederatedType(
        adaptive_fed_avg.ServerState(
            model=weights_type,
            optimizer_state=[tf.int64],
            client_lr_callback=callback_type,
            server_lr_callback=callback_type),
        tff.SERVER)

    self.assertEqual(
        process.initialize.type_signature,
        tff.FunctionType(parameter=None, result=state_type))

    batch_spec = collections.OrderedDict(
        x=tff.TensorType(tf.float32, [None, 1]),
        y=tff.TensorType(tf.float32, [None, 1]))
    datasets_type = tff.FederatedType(
        tff.SequenceType(batch_spec), tff.CLIENTS)
    loss_metrics_type = tff.FederatedType(
        collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
    outputs_type = collections.OrderedDict(
        before_training=loss_metrics_type,
        during_training=loss_metrics_type)

    next_parameter_type = collections.OrderedDict(
        server_state=state_type, federated_dataset=datasets_type)
    expected_next_type = tff.FunctionType(
        parameter=next_parameter_type,
        result=(state_type, outputs_type))

    actual_next_type = process.next.type_signature
    self.assertEqual(
        actual_next_type,
        expected_next_type,
        msg='{s}\n!={t}'.format(s=actual_next_type, t=expected_next_type))
def test_small_lr_comparable_zero_lr(self):
    """Checks that a client LR of 1e-8 gives results close to a client LR of 0.

    Two processes are built that differ only in the client callback's initial
    learning rate (0.0 vs 1e-8). After five rounds each, their trainable
    weights and their per-round outputs must agree to within 1e-4.
    """
    shared_kwargs = dict(
        min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0)
    zero_lr_client = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.0, **shared_kwargs)
    tiny_lr_client = callbacks.create_reduce_lr_on_plateau(
        learning_rate=1e-8, **shared_kwargs)
    # One server callback object is passed to both processes.
    server_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1, **shared_kwargs)

    def build_process(client_callback):
        # Everything except the client callback is the same for both builds.
        return adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_callback,
            server_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

    # Build both processes before running either one.
    zero_lr_process = build_process(zero_lr_client)
    tiny_lr_process = build_process(tiny_lr_client)

    zero_lr_state, zero_lr_outputs = self._run_rounds(zero_lr_process, 5)
    tiny_lr_state, tiny_lr_outputs = self._run_rounds(tiny_lr_process, 5)

    self.assertAllClose(zero_lr_state.model.trainable,
                        tiny_lr_state.model.trainable, 1e-4)
    self.assertAllClose(zero_lr_outputs, tiny_lr_outputs, 1e-4)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
) -> tff.templates.IterativeProcess:
    """Builds the adaptive federated averaging process for `model_fn`.

    The learning-rate callbacks and optimizer functions are not parameters;
    they are read from the names `client_lr_callback`, `server_lr_callback`,
    `client_optimizer_fn` and `server_optimizer_fn` in the surrounding scope.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    optimizer_kwargs = dict(
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)
    process = adaptive_fed_avg.build_fed_avg_process(
        model_fn, client_lr_callback, server_lr_callback, **optimizer_kwargs)
    return process
def test_get_model_weights(self):
    """Checks `get_model_weights` on the initial state and after training.

    For both the freshly initialized state and the state after five rounds,
    `get_model_weights(state)` must be a `tff.learning.ModelWeights` whose
    trainable weights are close to `state.model.trainable`.
    """
    plateau_kwargs = dict(
        learning_rate=0.1,
        window_size=1,
        patience=1,
        decay_factor=1.0,
        cooldown=0)
    client_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)
    server_callback = callbacks.create_reduce_lr_on_plateau(**plateau_kwargs)

    process = adaptive_fed_avg.build_fed_avg_process(
        _uncompiled_model_builder,
        client_callback,
        server_callback,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)

    def check_weights_match(server_state):
        # The same two assertions apply to any server state.
        self.assertIsInstance(
            process.get_model_weights(server_state),
            tff.learning.ModelWeights)
        self.assertAllClose(
            server_state.model.trainable,
            process.get_model_weights(server_state).trainable)

    state = process.initialize()
    check_weights_match(state)

    state, _ = self._run_rounds(process, 5)
    check_weights_match(state)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
    """Builds the adaptive federated averaging process for `model_fn`.

    The learning-rate callbacks and optimizer functions are not parameters;
    they are read from the names `client_lr_callback`, `server_lr_callback`,
    `client_optimizer_fn` and `server_optimizer_fn` in the surrounding scope.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that receives the output of
        `model.report_local_outputs` and returns a tensor giving the client's
        weight in the federated average of model deltas. When omitted, the
        default is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    keyword_args = dict(
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn)
    process = adaptive_fed_avg.build_fed_avg_process(
        model_fn, client_lr_callback, server_lr_callback, **keyword_args)
    return process