def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    # TODO(b/124477598): Remove dummy when b/121400757 has been fixed.
    @tf.function
    def reduce_fn(dummy, batch):
      model_output = model.forward_pass(batch, training=False)
      return dummy + tf.cast(model_output.loss, tf.float64)

    # TODO(b/123898430): The control dependencies below have been inserted as
    # a temporary workaround. These control dependencies need to be removed,
    # and defuns and datasets supported together fully.
    with tf.control_dependencies(
        [tff.utils.assign(model.weights, incoming_model_weights)]):
      dummy = dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)
    with tf.control_dependencies([dummy]):
      return collections.OrderedDict([
          ('local_outputs', model.report_local_outputs()),
          ('workaround for b/121400757', dummy)
      ])

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
def build_federated_evaluation(model_fn):
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics
    as aggregated by `tff.learning.Model.federated_output_computation`.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  # TODO(b/124477628): Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.to_type(
        tf.nest.map_structure(
            lambda v: tff.TensorType(v.dtype.base_dtype, v.shape),
            model.weights))
    batch_type = tff.to_type(model.input_spec)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs after evaluating `model_weights` on `dataset`."""
    model = model_utils.enhance(model_fn())

    @tf.function
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tff.utils.assign(model.weights, incoming_model_weights)

      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      dataset.reduce(tf.constant(0.0, dtype=tf.float64), reduce_fn)

      return collections.OrderedDict(
          [('local_outputs', model.report_local_outputs())])

    return _tf_client_eval(incoming_model_weights, dataset)

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(tff.SequenceType(batch_type), tff.CLIENTS))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_eval,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return model.federated_output_computation(client_outputs.local_outputs)

  return server_eval
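# Illustrative usage sketch (not part of the library): how the computation
# returned by `build_federated_evaluation` might be invoked in eager mode.
# Assumptions: `collections`, `tf`, `model_examples`, and `model_utils` are
# imported as in the surrounding files; the model's `input_spec` takes batches
# keyed by 'x' and 'y'; and zero weights stand in for weights that would
# normally come from a `tff.learning` training process. The function name and
# data below are made up.
def _example_federated_evaluation():

  def model_fn():
    return model_examples.LinearRegression(feature_dim=2)

  evaluation = build_federated_evaluation(model_fn)

  # Each client supplies a dataset of *batches*, matching the model's
  # `input_spec`.
  client_dataset = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict([
          ('x', [[1.0, 2.0], [3.0, 4.0]]),
          ('y', [[5.0], [6.0]]),
      ])).batch(2)

  zero_weights = tf.nest.map_structure(
      tf.zeros_like, model_utils.enhance(model_fn()).weights)

  # Returns the aggregated evaluation metrics for the two clients.
  return evaluation(zero_weights, [client_dataset, client_dataset])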
def test_success_with_valid_context(self):

  def model_fn():
    return model_examples.LinearRegression(feature_dim=2)

  zero_model_weights = _create_zero_model_weights(model_fn)
  p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)

  # Build the p13n eval with an extra `context` argument.
  context_tff_type = tff.to_type(tf.int32)
  federated_p13n_eval = p13n_eval.build_personalization_eval(
      model_fn, p13n_fn_dict, _evaluate_fn, context_tff_type=context_tff_type)

  # Perform p13n eval on two clients with different `context` values.
  results = federated_p13n_eval(zero_model_weights, [
      _create_client_input(train_scale=1.0, test_scale=1.0, context=2),
      _create_client_input(train_scale=1.0, test_scale=2.0, context=5)
  ])
  results = results._asdict(recursive=True)

  bs1_metrics = results['batch_size_1']
  bs2_metrics = results['batch_size_2']

  # Number of training examples is `3 + context` for both clients.
  # Note: the order is not preserved due to `federated_sample`, but the order
  # should be consistent across different personalization strategies.
  self.assertAllEqual(sorted(bs1_metrics['num_examples']), [5, 8])
  self.assertAllEqual(bs1_metrics['num_examples'],
                      bs2_metrics['num_examples'])
def test_failure_with_invalid_context_type(self):

  def model_fn():
    return model_examples.LinearRegression(feature_dim=2)

  zero_model_weights = _create_zero_model_weights(model_fn)
  p13n_fn_dict = _create_p13n_fn_dict(learning_rate=1.0)

  with self.assertRaises(TypeError):
    # `tf.int32` is not a `tff.Type`.
    bad_context_tff_type = tf.int32
    federated_p13n_eval = p13n_eval.build_personalization_eval(
        model_fn,
        p13n_fn_dict,
        _evaluate_fn,
        context_tff_type=bad_context_tff_type)

  with self.assertRaises(TypeError):
    # `context_tff_type` is provided but `context` is not provided.
    context_tff_type = tff.to_type(tf.int32)
    federated_p13n_eval = p13n_eval.build_personalization_eval(
        model_fn,
        p13n_fn_dict,
        _evaluate_fn,
        context_tff_type=context_tff_type)
    federated_p13n_eval(zero_model_weights, [
        _create_client_input(train_scale=1.0, test_scale=1.0, context=None),
        _create_client_input(train_scale=1.0, test_scale=2.0, context=None)
    ])
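# Illustrative sketch (hypothetical; these are *not* the `_create_p13n_fn_dict`
# and `_evaluate_fn` helpers used by the tests above): one plausible shape for
# a personalization strategy that consumes the per-client integer `context`,
# plus a matching evaluation function. Assumptions: `collections` and `tf` are
# imported, and `forward_pass` reports `num_examples` in its `BatchOutput`.
@tf.function
def example_evaluate_fn(model, dataset, batch_size=1):
  """Evaluates `model` on `dataset`; returns an `OrderedDict` of metrics."""
  num_examples = tf.constant(0, dtype=tf.int32)
  loss_sum = tf.constant(0.0, dtype=tf.float32)
  for batch in dataset.batch(batch_size):
    output = model.forward_pass(batch, training=False)
    num_examples += output.num_examples
    loss_sum += output.loss * tf.cast(output.num_examples, tf.float32)
  return collections.OrderedDict(
      [('loss', loss_sum / tf.cast(num_examples, tf.float32))])


def example_build_personalize_fn(learning_rate=1.0, batch_size=1):
  """Returns a `tf.function` implementing a simple fine-tuning strategy."""
  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

  @tf.function
  def personalize_fn(model, train_data, test_data, context=None):
    # Fine-tune the broadcast model on the client's training data.
    num_examples_sum = tf.constant(0, dtype=tf.int32)
    for batch in train_data.batch(batch_size):
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
      grads = tape.gradient(output.loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      num_examples_sum += output.num_examples
    # Fold the integer `context` into the reported count, mirroring the
    # `3 + context` bookkeeping checked by the test above.
    if context is not None:
      num_examples_sum += context
    metrics = example_evaluate_fn(model, test_data, batch_size)
    metrics['num_examples'] = num_examples_sum
    return metrics

  return personalize_fn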
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from SERVER to
  CLIENTS. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the SERVER.

  Args:
    model_fn: A no-argument function that returns a `tff.learning.Model`.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the provided
      model weights when users invoke the returned TFF computation), a
      training `tf.data.Dataset`, a test `tf.data.Dataset`, and an arbitrary
      context object (which is used to hold any extra information that a
      personalization strategy may use), trains a personalized model, and
      returns the evaluation metrics. The evaluation metrics are usually
      represented as an `OrderedDict` (or a nested `OrderedDict`) of `string`
      metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and a `tf.data.Dataset`,
      evaluates the model on the dataset, and returns the evaluation metrics.
      The evaluation metrics are usually represented as an `OrderedDict` (or a
      nested `OrderedDict`) of `string` metric names to scalar `tf.Tensor`s.
      This function is *only* used to compute the baseline metrics of the
      initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in
      a round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` that maps
    < model_weights@SERVER, input@CLIENTS > -> personalization_metrics@SERVER,
    where:
    - `model_weights` is a `tff.learning.framework.ModelWeights`.
    - Each client's input is an `OrderedDict` of at least two keys
      `train_data` and `test_data`, and each key is mapped to a
      `tf.data.Dataset`. If context is used in `personalize_fn_dict`, then the
      client input has a third key `context` that is mapped to an object whose
      `tff.Type` is provided by the `context_tff_type` argument.
    - `personalization_metrics` is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    - Note: only metrics from at most `max_num_samples` participating clients
      are collected to the SERVER. All collected metrics are stored in a
      single `OrderedDict` (the `personalization_metrics` shown above), where
      each metric is mapped to a list of scalars (each scalar comes from one
      client). Metric values at the same position, e.g., `metric_1[i]`,
      `metric_2[i]`, ..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = tff.to_type(model.input_spec)

  # Define the `tff.Type` of each client's input.
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(batch_type)),
      ('test_data', tff.SequenceType(batch_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    model = model_fn()
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)

    return _client_fn(model, initial_model_weights, train_data, test_data,
                      personalize_fn_dict, baseline_evaluate_fn, context)

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
def build_personalization_eval(model_fn,
                               personalize_fn_dict,
                               baseline_evaluate_fn,
                               max_num_samples=100,
                               context_tff_type=None):
  """Builds the TFF computation for evaluating personalization strategies.

  The returned TFF computation broadcasts model weights from `tff.SERVER` to
  `tff.CLIENTS`. Each client evaluates the personalization strategies given in
  `personalize_fn_dict`. Evaluation metrics from at most `max_num_samples`
  participating clients are collected to the server.

  NOTE: The functions in `personalize_fn_dict` and `baseline_evaluate_fn` are
  expected to take as input *unbatched* datasets, and are responsible for
  applying batching, if any, to the provided input datasets.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model on each call will result in an
      error.
    personalize_fn_dict: An `OrderedDict` that maps a `string` (representing a
      strategy name) to a no-argument function that returns a `tf.function`.
      Each `tf.function` represents a personalization strategy: it accepts a
      `tff.learning.Model` (with weights already initialized to the given
      model weights when users invoke the returned TFF computation), an
      unbatched `tf.data.Dataset` for train, an unbatched `tf.data.Dataset`
      for test, and an arbitrary context object (which is used to hold any
      extra information that a personalization strategy may use), trains a
      personalized model, and returns the evaluation metrics. The evaluation
      metrics are represented as an `OrderedDict` (or a nested `OrderedDict`)
      of `string` metric names to scalar `tf.Tensor`s.
    baseline_evaluate_fn: A `tf.function` that accepts a `tff.learning.Model`
      (with weights already initialized to the provided model weights when
      users invoke the returned TFF computation), and an unbatched
      `tf.data.Dataset`, evaluates the model on the dataset, and returns the
      evaluation metrics. The evaluation metrics are represented as an
      `OrderedDict` (or a nested `OrderedDict`) of `string` metric names to
      scalar `tf.Tensor`s. This function is *only* used to compute the
      baseline metrics of the initial model.
    max_num_samples: A positive `int` specifying the maximum number of metric
      samples to collect in a round. Each sample contains the personalization
      metrics from a single client. If the number of participating clients in
      a round is smaller than this value, all clients' metrics are collected.
    context_tff_type: A `tff.Type` of the optional context object used by the
      personalization strategies defined in `personalize_fn_dict`. We use a
      context object to hold any extra information (in addition to the
      training dataset) that personalization may use. If context is used in
      `personalize_fn_dict`, its `tff.Type` must be provided here.

  Returns:
    A federated `tff.Computation` with the functional type signature
    `(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER)`:

    * `model_weights` is a `tff.learning.ModelWeights`.
    * Each client's input is an `OrderedDict` of two required keys
      `train_data` and `test_data`; each key is mapped to an unbatched
      `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in
      `personalize_fn_dict`, then the client input has a third key `context`
      that is mapped to an object whose `tff.Type` is provided by the
      `context_tff_type` argument.
    * `personalization_metrics` is an `OrderedDict` that maps a key
      'baseline_metrics' to the evaluation metrics of the initial model
      (computed by `baseline_evaluate_fn`), and maps keys (strategy names) in
      `personalize_fn_dict` to the evaluation metrics of the corresponding
      personalization strategies.
    * Note: only metrics from at most `max_num_samples` participating clients
      (sampled without replacement) are collected to the SERVER. All collected
      metrics are stored in a single `OrderedDict` (`personalization_metrics`
      shown above), where each metric is mapped to a list of scalars (each
      scalar comes from one client). Metric values at the same position, e.g.,
      `metric_1[i]`, `metric_2[i]`, ..., all come from the same client.

  Raises:
    TypeError: If arguments are of the wrong types.
    ValueError: If `baseline_metrics` is used as a key in
      `personalize_fn_dict`.
    ValueError: If `max_num_samples` is not positive.
  """
  # Obtain the types by constructing the model first.
  # TODO(b/124477628): Replace it with other ways of handling metadata.
  with tf.Graph().as_default():
    py_typecheck.check_callable(model_fn)
    model = model_utils.enhance(model_fn())
    model_weights_type = tff.framework.type_from_tensors(model.weights)
    batch_type = model.input_spec

  # Define the `tff.Type` of each client's input. Since batching (as well as
  # other preprocessing of datasets) is done within each personalization
  # strategy (i.e., by functions in `personalize_fn_dict`), the client-side
  # input should contain unbatched elements.
  element_type = _remove_batch_dim(batch_type)
  client_input_type = collections.OrderedDict([
      ('train_data', tff.SequenceType(element_type)),
      ('test_data', tff.SequenceType(element_type))
  ])
  if context_tff_type is not None:
    py_typecheck.check_type(context_tff_type, tff.Type)
    client_input_type['context'] = context_tff_type
  client_input_type = tff.to_type(client_input_type)

  @tff.tf_computation(model_weights_type, client_input_type)
  def _client_computation(initial_model_weights, client_input):
    """TFF computation that runs on each client."""
    train_data = client_input['train_data']
    test_data = client_input['test_data']
    context = client_input.get('context', None)

    final_metrics = collections.OrderedDict()
    # Compute the evaluation metrics of the initial model.
    final_metrics['baseline_metrics'] = _compute_baseline_metrics(
        model_fn, initial_model_weights, test_data, baseline_evaluate_fn)

    py_typecheck.check_type(personalize_fn_dict, collections.OrderedDict)
    if 'baseline_metrics' in personalize_fn_dict:
      raise ValueError('baseline_metrics should not be used as a key in '
                       'personalize_fn_dict.')

    # Compute the evaluation metrics of the personalized models. The returned
    # `p13n_metrics` is an `OrderedDict` that maps keys (strategy names) in
    # `personalize_fn_dict` to the evaluation metrics of the corresponding
    # personalization strategies.
    p13n_metrics = _compute_p13n_metrics(model_fn, initial_model_weights,
                                         train_data, test_data,
                                         personalize_fn_dict, context)
    final_metrics.update(p13n_metrics)
    return final_metrics

  py_typecheck.check_type(max_num_samples, int)
  if max_num_samples <= 0:
    raise ValueError('max_num_samples must be a positive integer.')

  @tff.federated_computation(
      tff.FederatedType(model_weights_type, tff.SERVER),
      tff.FederatedType(client_input_type, tff.CLIENTS))
  def personalization_eval(server_model_weights, federated_client_input):
    """TFF orchestration logic."""
    client_init_weights = tff.federated_broadcast(server_model_weights)
    client_final_metrics = tff.federated_map(
        _client_computation, (client_init_weights, federated_client_input))

    # WARNING: Collecting information from clients can be risky. Users have to
    # make sure that it is proper to collect those metrics from clients.
    # TODO(b/147889283): Add a link to the TFF doc once it exists.
    results = tff.utils.federated_sample(client_final_metrics, max_num_samples)
    return results

  return personalization_eval
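# Illustrative usage sketch (hypothetical): wiring `build_personalization_eval`
# together with the `example_build_personalize_fn` / `example_evaluate_fn`
# sketches above and the `model_examples.LinearRegression` model used by the
# tests. The client data is made up and assumes the model's `input_spec` is an
# `OrderedDict` with 'x' and 'y' keys; the per-client datasets are left
# unbatched, as this version of the builder requires.
def _example_personalization_eval():

  def model_fn():
    return model_examples.LinearRegression(feature_dim=2)

  # Values are no-argument builders (as the docstring requires), so each
  # client constructs fresh strategy state (e.g., a fresh optimizer).
  p13n_fn_dict = collections.OrderedDict([
      ('batch_size_1',
       lambda: example_build_personalize_fn(learning_rate=1.0, batch_size=1)),
      ('batch_size_2',
       lambda: example_build_personalize_fn(learning_rate=1.0, batch_size=2)),
  ])

  federated_p13n_eval = build_personalization_eval(
      model_fn, p13n_fn_dict, example_evaluate_fn, max_num_samples=100)

  def client_input(scale):
    return collections.OrderedDict([
        ('train_data',
         tf.data.Dataset.from_tensor_slices(
             collections.OrderedDict([
                 ('x', [[1.0 * scale, 2.0 * scale],
                        [3.0 * scale, 4.0 * scale]]),
                 ('y', [[5.0 * scale], [6.0 * scale]]),
             ]))),
        ('test_data',
         tf.data.Dataset.from_tensor_slices(
             collections.OrderedDict([('x', [[1.0, 2.0]]), ('y', [[3.0]])]))),
    ])

  zero_model_weights = tf.nest.map_structure(
      tf.zeros_like, model_utils.enhance(model_fn()).weights)

  # Metrics from the two clients, sampled (up to `max_num_samples`) and
  # collected at the server.
  return federated_p13n_eval(
      zero_model_weights, [client_input(1.0), client_input(2.0)])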