def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> tff.templates.IterativeProcess:
  """Creates an iterative process using a given TFF `model_fn`.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor providing the weight
      in the federated average of model deltas. If not provided, the default
      is the total number of examples processed on device.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # NOTE(review): `client_optimizer_fn`, `client_lr_schedule`,
  # `server_optimizer_fn` and `server_lr_schedule` are free variables captured
  # from the enclosing scope (presumably built from flags there) — confirm
  # they are defined before this builder is invoked.
  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)
def test_execute_with_preprocess_function(self):
  """Training with a client dataset preprocessing computation converges."""
  raw_dataset = tf.data.Dataset.range(1)

  @tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_example(x):
      del x  # Unused.
      return _Batch(
          x=np.ones([784], dtype=np.float32),
          y=np.ones([1], dtype=np.int64))

    return ds.map(to_example).batch(1)

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)
  _, outputs, _ = self._run_rounds(process, [raw_dataset], 6)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
  # Loss improvement should slow down over rounds.
  first_half_gap = outputs[0]['loss'] - outputs[2]['loss']
  second_half_gap = outputs[3]['loss'] - outputs[5]['loss']
  self.assertLess(second_half_gap, first_half_gap)
def test_build_evaluate_fn(self):
  """The evaluation function built by `training_utils` reports a loss."""
  loss_builder = tf.keras.losses.MeanSquaredError
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        dummy_batch=get_sample_batch(),
        loss=loss_builder(),
        metrics=metrics_builder())

  process = fed_avg_schedule.build_fed_avg_process(
      tff_model_fn, client_optimizer_fn=tf.keras.optimizers.SGD)
  state = process.initialize()

  eval_dataset = create_tf_dataset_for_client(1)
  evaluate_fn = training_utils.build_evaluate_fn(
      eval_dataset, model_builder, loss_builder, metrics_builder)
  metrics = evaluate_fn(state)
  self.assertIn('loss', metrics)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
  """Creates a federated averaging process using a given TFF `model_fn`.

  Uses SGD on both client and server with fixed learning rates of 0.1 and
  1.0 respectively (a server LR of 1.0 makes the server step a plain
  weighted average of client deltas).

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor providing the weight
      in the federated average of model deltas. If not provided, the default
      is the total number of examples processed on device.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  return fed_avg_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.1,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=1.0,
      client_weight_fn=client_weight_fn)
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
) -> tff.templates.IterativeProcess:
  """Constructs a `tff.templates.IterativeProcess` from command-line flags.

  Learning rate schedules for both client and server optimizers are read
  from flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
      object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`,
      defaults to the number of examples processed over all batches.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
def test_server_update_with_nan_data_is_noop(self):
  """A round trained on NaN-valued data must leave the server model unchanged."""
  client_data = [[_batch_fn(has_nan=True)]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  final_state, _, initial_state = self._run_rounds(process, client_data, 1)
  self.assertAllClose(final_state.model, initial_state.model, 1e-8)
def test_fed_avg_without_schedule_decreases_loss(self):
  """Plain federated averaging (constant learning rates) reduces loss."""
  client_data = [[_batch_fn()]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  _, outputs, _ = self._run_rounds(process, client_data, 5)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
def from_flags(dummy_batch,
               model_builder,
               loss_builder,
               metrics_builder,
               client_weight_fn=None):
  """Constructs a `tff.utils.IterativeProcess` from command-line flags.

  Learning rate schedules for both client and server optimizers are read
  from flags.

  Args:
    dummy_batch: A nested structure of values that are convertible to batched
      tensors with the same shapes and types as expected in the forward pass
      of training. The actual values are not important and can hold any
      reasonable value.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
      object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`,
      defaults to the number of examples processed over all batches.

  Returns:
    A `tff.utils.IterativeProcess` instance.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        dummy_batch=dummy_batch,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
def test_conversion_from_tff_result(self):
  """`ServerState.from_tff_result` yields typed `ServerState`/`ModelWeights`."""
  client_data = [[_batch_fn()]]
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  raw_state, _, _ = self._run_rounds(process, client_data, 1)
  converted = fed_avg_schedule.ServerState.from_tff_result(raw_state)
  self.assertIsInstance(converted, fed_avg_schedule.ServerState)
  self.assertIsInstance(converted.model, fed_avg_schedule.ModelWeights)
def test_server_update_with_inf_weight_is_noop(self):
  """An infinite client weight must not move the server model."""
  client_data = [[_batch_fn()]]

  def infinite_weight(local_outputs):
    del local_outputs  # Unused.
    return np.inf

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      client_weight_fn=infinite_weight)
  final_state, _, initial_state = self._run_rounds(process, client_data, 1)
  self.assertAllClose(final_state.model, initial_state.model, 1e-8)
def test_fed_avg_with_custom_client_weight_fn(self):
  """Training still reduces loss under a custom client weighting."""
  client_data = [[_batch_fn()]]
  # Weight each client inversely to its final per-batch loss.
  client_weight_fn = lambda outputs: 1.0 / (1.0 + outputs['loss'][-1])
  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      client_weight_fn=client_weight_fn)
  _, outputs, _ = self._run_rounds(process, client_data, 5)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
def test_fed_avg_with_client_and_server_schedules(self):
  """Decaying client/server learning rates shrink per-round improvement."""
  client_data = [[_batch_fn()]]

  def client_lr(round_num):
    return 0.1 / (round_num + 1)**2

  def server_lr(round_num):
    return 1.0 / (round_num + 1)**2

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=client_lr,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=server_lr)
  _, outputs, _ = self._run_rounds(process, client_data, 6)
  self.assertLess(outputs[-1]['loss'], outputs[0]['loss'])
  # With quadratically decaying rates, later rounds improve less.
  first_half_gap = outputs[0]['loss'] - outputs[2]['loss']
  second_half_gap = outputs[3]['loss'] - outputs[5]['loss']
  self.assertLess(second_half_gap, first_half_gap)
def test_build_with_preprocess_function(self):
  """Composing a preprocess computation with the process yields the expected
  `next` type signature (client datasets of raw elements, not batches)."""
  test_dataset = tf.data.Dataset.range(5)
  # The composed process should accept the *raw* (pre-preprocessing)
  # sequence type at CLIENTS placement.
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(test_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_batch(x):
      return _Batch(
          tf.fill(dims=(784,), value=float(x) * 2.0),
          tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

    return ds.map(to_batch).batch(2)

  iterproc = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, iterproc)

  # Build a throwaway model in a fresh graph purely to read off the tensor
  # types for the expected server-state signature.
  with tf.Graph().as_default():
    test_model_for_types = _uncompiled_model_builder()

  server_state_type = tff.FederatedType(
      fed_avg_schedule.ServerState(
          model=tff.framework.type_from_tensors(
              tff.learning.ModelWeights(
                  test_model_for_types.trainable_variables,
                  test_model_for_types.non_trainable_variables)),
          optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  metrics_type = test_model_for_types.federated_output_computation.type_signature.result

  expected_parameter_type = collections.OrderedDict(
      server_state=server_state_type,
      federated_dataset=client_datasets_type,
  )
  expected_result_type = (server_state_type, metrics_type)
  expected_type = tff.FunctionType(
      parameter=expected_parameter_type, result=expected_result_type)
  self.assertTrue(
      iterproc.next.type_signature.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(
          s=iterproc.next.type_signature, t=expected_type))
def test_fed_avg_with_server_schedule(self):
  """Zeroing the server LR after two rounds freezes the server model."""
  client_data = [[_batch_fn()]]

  @tf.function
  def lr_schedule(round_num):
    # 1.0 for rounds 0 and 1, then 0.0 afterwards.
    return 1.0 if round_num < 1.5 else 0.0

  process = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      server_lr=lr_schedule)
  _, outputs, _ = self._run_rounds(process, client_data, 4)
  self.assertLess(outputs[1]['loss'], outputs[0]['loss'])
  # Once the LR hits zero, consecutive losses should be nearly identical.
  self.assertNear(outputs[2]['loss'], outputs[3]['loss'], err=1e-4)
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
  """Builds a federated averaging iterative process for `model_fn`.

  The optimizer constructors and learning-rate schedules
  (`client_optimizer_fn`, `client_lr_schedule`, `server_optimizer_fn`,
  `server_lr_schedule`) are taken from the enclosing scope.

  Args:
    model_fn: A no-arg function returning a `tff.learning.Model`.
    client_weight_fn: Optional function that maps the output of
      `model.report_local_outputs` to a tensor used as the client's weight
      in the federated average of model deltas. Defaults to the total
      number of examples processed on device when not provided.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  build_kwargs = dict(
      model_fn=model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn)
  return fed_avg_schedule.build_fed_avg_process(**build_kwargs)
def _build_federated_averaging_process():
  """Builds a plain federated averaging process with SGD on both sides."""
  sgd = tf.keras.optimizers.SGD
  return fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_fn,
      client_optimizer_fn=sgd,
      server_optimizer_fn=sgd)
def from_flags(
    input_spec,
    model_builder: ModelBuilder,
    loss_builder: LossBuilder,
    metrics_builder: MetricsBuilder,
    client_weight_fn: Optional[ClientWeightFn] = None,
    *,
    dataset_preprocess_comp: Optional[tff.Computation] = None,
) -> fed_avg_schedule.FederatedAveragingProcessAdapter:
  """Builds a `tff.templates.IterativeProcess` instance from flags.

  The iterative process is designed to incorporate learning rate schedules,
  which are configured via flags.

  Args:
    input_spec: A value convertible to a `tff.Type`, representing the data
      which will be fed into the `tff.templates.IterativeProcess.next`
      function over the course of training. Generally, this can be found by
      accessing the `element_spec` attribute of a client `tf.data.Dataset`.
    model_builder: A no-arg function that returns an uncompiled
      `tf.keras.Model` object.
    loss_builder: A no-arg function returning a `tf.keras.losses.Loss`
      object.
    metrics_builder: A no-arg function that returns a list of
      `tf.keras.metrics.Metric` objects.
    client_weight_fn: An optional callable that takes the result of
      `tff.learning.Model.report_local_outputs` from the model returned by
      `model_builder`, and returns a scalar client weight. If `None`,
      defaults to the number of examples processed over all batches.
    dataset_preprocess_comp: Optional `tff.Computation` that sets up a data
      pipeline on the clients. The computation must take a sequence of values
      and return a sequence of values, or in TFF type shorthand `(U* -> V*)`.
      If `None`, no dataset preprocessing is applied. If specified,
      `input_spec` is optional, as the necessary type signatures will be
      taken from the computation.

  Returns:
    A `fed_avg_schedule.FederatedAveragingProcessAdapter`.

  Raises:
    ValueError: If both `input_spec` and `dataset_preprocess_comp` are
      `None`, since the model's input type cannot be determined.
  """
  # TODO(b/147808007): Assert that model_builder() returns an uncompiled keras
  # model.
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('client')
  server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags('server')

  if dataset_preprocess_comp is not None:
    if input_spec is not None:
      print('Specified both `dataset_preprocess_comp` and `input_spec` when '
            'only one is necessary. Ignoring `input_spec` and using type '
            'signature of `dataset_preprocess_comp`.')
    # The model consumes whatever element type the preprocess pipeline emits.
    model_input_spec = dataset_preprocess_comp.type_signature.result.element
  else:
    if input_spec is None:
      # Fail fast with a clear message instead of an obscure error inside
      # `tff.learning.from_keras_model` later on.
      raise ValueError(
          'Must specify `input_spec` when `dataset_preprocess_comp` is '
          'None.')
    model_input_spec = input_spec

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=model_input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  return fed_avg_schedule.build_fed_avg_process(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      client_weight_fn=client_weight_fn,
      dataset_preprocess_comp=dataset_preprocess_comp)