def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model]
) -> tff.templates.IterativeProcess:
  """Creates an iterative process using a given TFF `model_fn`.

  For token-based sequence tasks the clients are weighted by the number of
  tokens they processed; for all other tasks the default weighting (total
  number of examples) inside `build_fed_avg_process` is used.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Idiom fix: a single membership test instead of chained equality checks.
  if FLAGS.task in ('shakespeare', 'stackoverflow_nwp'):

    def client_weight_fn(local_outputs):
      # NOTE(review): assumes `num_tokens` is a size-1 tensor — squeeze to a
      # scalar and cast to float32 for use as an averaging weight. Confirm
      # against the metrics emitted by the task's model.
      return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
  else:
    client_weight_fn = None

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
def test_execute_with_preprocess_function(self):
  """Training through a composed preprocessing computation reduces loss."""
  raw_dataset = tf.data.Dataset.range(1)

  @tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_example(x):
      del x  # Unused.
      return collections.OrderedDict(x=[3.0], y=[2.0])

    return ds.map(to_example).batch(1)

  process = fed_avg_schedule.build_fed_avg_process(
      model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.01,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  process = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, process)

  _, train_outputs, _ = self._run_rounds(process, [raw_dataset], 6)
  losses = [metrics['loss'] for metrics in train_outputs]
  self.assertLess(losses[-1], losses[0])
  # The later rounds should improve less than the earlier ones.
  first_half_gap = losses[0] - losses[2]
  second_half_gap = losses[3] - losses[5]
  self.assertLess(second_half_gap, first_half_gap)
def iterative_process_builder(model_fn, client_weight_fn=None):
  """Builds a federated averaging process with fixed SGD learning rates."""
  # Both client and server use plain SGD; only the learning rates differ.
  sgd = tf.keras.optimizers.SGD
  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=sgd,
      client_lr=0.1,
      server_optimizer_fn=sgd,
      server_lr=1.0,
      client_weight_fn=client_weight_fn)
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The returned process supports learning rate schedules, both of which are
  configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
      object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`,
      defaults to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def tff_model_fn() -> tff.learning.Model:
    # A fresh Keras model must be constructed on every call so that TFF can
    # create the model variables in the proper context.
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
def test_fed_avg_without_schedule_decreases_loss(self):
  """Federated averaging with default learning rates reduces the loss."""
  client_data = [[_batch_fn()]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  _, outputs, _ = self._run_rounds(process, client_data, 5)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
def test_server_update_with_nan_data_is_noop(self):
  """A round over NaN-valued client data must leave the model unchanged."""
  client_data = [[_batch_fn(has_nan=True)]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  final_state, _, initial_state = self._run_rounds(process, client_data, 1)
  # Both trainable and non-trainable weights must stay at their initial
  # values when the client update is non-finite.
  self.assertAllClose(final_state.model.trainable,
                      initial_state.model.trainable, 1e-8)
  self.assertAllClose(final_state.model.non_trainable,
                      initial_state.model.non_trainable, 1e-8)
def test_fed_avg_with_custom_client_weight_fn(self):
  """Training with a loss-based client weight still decreases the loss."""
  client_data = [[_batch_fn()]]

  def client_weight_fn(local_outputs):
    # Down-weight clients whose final-batch loss is large.
    return 1.0 / (1.0 + local_outputs['loss'][-1])

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      client_weight_fn=client_weight_fn)
  _, outputs, _ = self._run_rounds(process, client_data, 5)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
def test_fed_avg_with_client_and_server_schedules(self):
  """Inverse-quadratic LR decay on both sides shrinks later-round gains."""
  client_data = [[_batch_fn()]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=lambda x: 0.1 / (x + 1)**2,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=lambda x: 1.0 / (x + 1)**2)
  _, outputs, _ = self._run_rounds(process, client_data, 6)
  losses = [metrics['loss'] for metrics in outputs]
  self.assertLess(losses[-1], losses[0])
  # With decaying rates, the second half should improve less than the first.
  self.assertLess(losses[3] - losses[5], losses[0] - losses[2])
def test_build_with_preprocess_function(self):
  """Checks `next`'s type signature after composing in a preprocessing fn.

  NOTE(review): composing a dataset computation should change only the
  client-data parameter type (to the raw, un-preprocessed element spec),
  leaving the server-state and metrics types untouched — confirm this is
  the intended contract of `compose_dataset_computation_with_...`.
  """
  test_dataset = tf.data.Dataset.range(5)
  # Expected client-side parameter: a federated sequence of the RAW
  # (pre-preprocessing) dataset elements.
  client_datasets_type = tff.type_at_clients(
      tff.SequenceType(test_dataset.element_spec))

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):

    # Expands each scalar element into a full (x, y) training batch.
    def to_batch(x):
      return _Batch(
          tf.fill(dims=(784,), value=float(x) * 2.0),
          tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

    return ds.map(to_batch).batch(2)

  iterproc = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, iterproc)

  # Build a throwaway model purely to read off its variable/type structure;
  # the graph context keeps its variables out of the default eager scope.
  with tf.Graph().as_default():
    test_model_for_types = _uncompiled_model_builder()

  # Hand-assemble the expected server state type: model weights, a
  # single-element optimizer state, and a float32 round number.
  server_state_type = tff.FederatedType(
      fed_avg_schedule.ServerState(
          model=tff.framework.type_from_tensors(
              tff.learning.ModelWeights(
                  test_model_for_types.trainable_variables,
                  test_model_for_types.non_trainable_variables)),
          optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  metrics_type = test_model_for_types.federated_output_computation.type_signature.result

  expected_parameter_type = collections.OrderedDict(
      server_state=server_state_type,
      federated_dataset=client_datasets_type,
  )
  expected_result_type = (server_state_type, metrics_type)

  expected_type = tff.FunctionType(
      parameter=expected_parameter_type, result=expected_result_type)
  # Equivalence (not equality) so representation-only differences pass.
  self.assertTrue(
      iterproc.next.type_signature.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(
          s=iterproc.next.type_signature, t=expected_type))
def test_server_update_with_inf_weight_is_noop(self):
  """An infinite client weight must not perturb the server model."""
  client_data = [create_dataset()]
  inf_weight_fn = lambda x: np.inf
  process = fed_avg_schedule.build_fed_avg_process(
      model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.01,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      client_weight_fn=inf_weight_fn)
  final_state, _, initial_state = self._run_rounds(process, client_data, 1)
  # Non-finite aggregate weight => the server update is skipped entirely.
  self.assertAllClose(final_state.model.trainable,
                      initial_state.model.trainable, 1e-8)
  self.assertAllClose(final_state.model.non_trainable,
                      initial_state.model.non_trainable, 1e-8)
def test_fed_avg_with_server_schedule(self):
  """A step-function server LR freezes the model once the rate hits zero."""
  client_data = [[_batch_fn()]]

  @tf.function
  def lr_schedule(x):
    # Full server updates for the first two rounds, none afterwards.
    return 1.0 if x < 1.5 else 0.0

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=lr_schedule)
  _, outputs, _ = self._run_rounds(process, client_data, 4)
  self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
  # After the rate drops to zero, the loss should stop moving.
  self.assertNear(outputs[2]['loss'], outputs[3]['loss'], err=1e-4)
def test_get_model_weights(self):
  """`get_model_weights` mirrors the state's model before and after rounds."""
  client_data = [create_dataset()]
  process = fed_avg_schedule.build_fed_avg_process(
      model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.01,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  def check_weights(state):
    # The extracted weights must be the canonical TFF container and agree
    # with the trainable weights stored on the server state.
    weights = process.get_model_weights(state)
    self.assertIsInstance(weights, tff.learning.ModelWeights)
    self.assertAllClose(state.model.trainable, weights.trainable)

  state = process.initialize()
  check_weights(state)
  for _ in range(3):
    state, _ = process.next(state, client_data)
    check_weights(state)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
  """Creates an iterative process using a given TFF `model_fn`.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor providing the weight
      in the federated average of model deltas. If not provided, the default
      is the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Optimizers and schedules come from the enclosing module scope; only the
  # model and weighting vary per call.
  schedule_kwargs = dict(
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule)
  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn, client_weight_fn=client_weight_fn, **schedule_kwargs)