def __init__(self,
             model: model_lib.Model,
             batch_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
             use_experimental_simulation_loop: bool = False):
  """Constructs the client computation for Federated SGD.

  Note: All variable creation required for the client computation (e.g. model
  variable construction) must occur during construction, and not during
  `__call__`.

  Args:
    model: A `learning.Model` for which gradients are computed.
    batch_weight_fn: A function that takes a batch (as passed to
      `forward_pass`) and returns a float32 weight. If not provided, the
      default uses the size of the batch (as measured by the batch dimension
      of the predictions returned by `forward_pass`).
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
  """
  if batch_weight_fn is not None:
    py_typecheck.check_callable(batch_weight_fn)
  self._batch_weight_fn = batch_weight_fn

  self._model = model_utils.enhance(model)
  py_typecheck.check_type(self._model, model_utils.EnhancedModel)

  self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
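# A hedged sketch (not from the library) of a custom `batch_weight_fn`
# matching the contract documented above: it receives a batch as passed to
# `forward_pass` and returns a float32 weight. The 'x' feature key is a
# hypothetical assumption about the batch structure.
def _example_batch_weight_fn(batch):
  # Weight each batch by its batch dimension, mirroring the documented
  # default behavior.
  return tf.cast(tf.shape(batch['x'])[0], tf.float32)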
def __init__(self,
             model: model_lib.Model,
             optimizer: tf.keras.optimizers.Optimizer,
             client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
             use_experimental_simulation_loop: bool = False):
  """Creates the client computation for Federated Averaging.

  Note: All variable creation required for the client computation (e.g. model
  variable creation) must occur during construction, and not during
  `__call__`.

  Args:
    model: A `tff.learning.Model` instance.
    optimizer: A `tf.keras.optimizers.Optimizer` instance.
    client_weight_fn: An optional callable that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
  """
  py_typecheck.check_type(model, model_lib.Model)
  self._model = model_utils.enhance(model)
  self._optimizer = optimizer
  py_typecheck.check_type(self._model, model_utils.EnhancedModel)

  if client_weight_fn is not None:
    py_typecheck.check_callable(client_weight_fn)
  self._client_weight_fn = client_weight_fn

  self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
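# A hedged sketch (not from the library) of a custom `client_weight_fn`: it
# receives the output of `model.report_local_outputs` and returns a weight
# tensor. The 'num_examples' key is a hypothetical assumption about the
# model's local outputs and is not guaranteed for all models.
def _example_client_weight_fn(local_outputs):
  # Weight each client by the square root of its example count, which damps
  # the influence of clients with very large datasets.
  return tf.math.sqrt(tf.cast(local_outputs['num_examples'], tf.float32))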
def client_eval(incoming_model_weights, dataset):
  """Returns local outputs after evaluating `model_weights` on `dataset`."""
  with tf.init_scope():
    model = model_fn()
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        incoming_model_weights)

  def reduce_fn(num_examples, batch):
    model_output = model.forward_pass(batch, training=False)
    if model_output.num_examples is None:
      # The model did not report a batch size; infer it from the batch
      # dimension of the predictions.
      return num_examples + tf.shape(
          model_output.predictions, out_type=tf.int64)[0]
    else:
      return num_examples + tf.cast(model_output.num_examples, tf.int64)

  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
  num_examples = dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
  return collections.OrderedDict(
      local_outputs=model.report_local_outputs(), num_examples=num_examples)
def test_build_dataset_reduce_fn_float(self, simulation, reduce_fn):
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
  self.assertIs(dataset_reduce_fn, reduce_fn)
  ds = tf.data.Dataset.range(
      10, output_type=tf.float32).map(lambda x: 0.1 * x)
  total_sum = dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
  self.assertEqual(total_sum, np.float32(4.5))
def __init__(
    self,
    model: model_lib.Model,
    client_weighting: client_weight_lib.ClientWeightType = client_weight_lib
    .ClientWeighting.NUM_EXAMPLES,
    use_experimental_simulation_loop: bool = False):
  """Constructs the client computation for Federated SGD.

  Note: All variable creation required for the client computation (e.g. model
  variable construction) must occur during construction, and not during
  `__call__`.

  Args:
    model: A `learning.Model` for which gradients are computed.
    client_weighting: A value of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method, or a callable that takes the
      output of `model.report_local_outputs` and returns a tensor that
      provides the weight in the federated average of model deltas.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
  """
  client_weight_lib.check_is_client_weighting_or_callable(client_weighting)
  self._client_weighting = client_weighting
  self._model = model
  py_typecheck.check_type(self._model, model_lib.Model)
  self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
def test_build_dataset_reduce_fn_tuple(self, simulation, reduce_fn):
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
  self.assertIs(dataset_reduce_fn, reduce_fn)
  ds = tf.data.Dataset.range(
      10, output_type=tf.float32).map(lambda x: 0.1 * x)
  # The state is a (count, sum) tuple, so the reduce function must index into
  # it rather than add to the tuple directly.
  total_cnt, total_sum = dataset_reduce_fn(
      reduce_fn=lambda x, y: (x[0] + 1, x[1] + y),
      dataset=ds,
      initial_state_fn=lambda: (tf.constant(0), tf.constant(0.1)))
  self.assertEqual(total_cnt, np.float32(10))
  self.assertEqual(total_sum, np.float32(4.6))
def test_dataset_reduce_op_presence(self, simulation):
  with tf.Graph().as_default() as graph:
    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
    ds = tf.data.Dataset.range(10, output_type=tf.int32)
    dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
  if simulation:
    self.assertIn(DATASET_REDUCE_OP, _get_op_names(graph.as_graph_def()))
  else:
    self.assertNotIn(DATASET_REDUCE_OP, _get_op_names(graph.as_graph_def()))
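# A minimal sketch (not part of the original test suite) of the property the
# test above exercises: both loop variants compute the same value; only the
# ops emitted into the serialized graph differ. Uses the same
# `dataset_reduce` import as the tests.
def _sketch_equivalent_reductions():
  ds = tf.data.Dataset.range(5, output_type=tf.int32)
  for simulation in (False, True):
    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
    total = dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
    tf.debugging.assert_equal(total, tf.constant(10, dtype=tf.int32))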
def _tf_client_eval(incoming_model_weights, dataset):
  """Evaluation TF work."""
  tf_computation_utils.assign(model.weights, incoming_model_weights)

  def reduce_fn(prev_loss, batch):
    model_output = model.forward_pass(batch, training=False)
    return prev_loss + tf.cast(model_output.loss, tf.float64)

  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
  dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=lambda: tf.constant(0, dtype=tf.float64))
  return collections.OrderedDict([('local_outputs',
                                   model.report_local_outputs())])
def client_eval(incoming_model_weights, dataset):
  """Returns local outputs after evaluating `model_weights` on `dataset`."""
  with tf.init_scope():
    model = model_fn()
  model_weights = model_utils.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        incoming_model_weights)

  def reduce_fn(prev_loss, batch):
    model_output = model.forward_pass(batch, training=False)
    return prev_loss + tf.cast(model_output.loss, tf.float64)

  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
  dataset_reduce_fn(
      reduce_fn=reduce_fn,
      dataset=dataset,
      initial_state_fn=lambda: tf.constant(0, dtype=tf.float64))
  return collections.OrderedDict(local_outputs=model.report_local_outputs())
def __init__(self,
             model: model_lib.Model,
             optimizer: tf.keras.optimizers.Optimizer,
             client_weighting: Union[
                 ClientWeighting,
                 ClientWeightFnType] = ClientWeighting.NUM_EXAMPLES,
             use_experimental_simulation_loop: bool = False):
  """Creates the client computation for Federated Averaging.

  Note: All variable creation required for the client computation (e.g. model
  variable creation) must occur during construction, and not during
  `__call__`.

  Args:
    model: A `tff.learning.Model` instance.
    optimizer: A `tf.keras.optimizers.Optimizer` instance.
    client_weighting: A value of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method, or a callable that takes the
      output of `model.report_local_outputs` and returns a tensor that
      provides the weight in the federated average of model deltas.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
  """
  py_typecheck.check_type(model, model_lib.Model)
  self._model = model_utils.enhance(model)
  self._optimizer = optimizer
  py_typecheck.check_type(self._model, model_utils.EnhancedModel)
  if (not isinstance(client_weighting, ClientWeighting) and
      not callable(client_weighting)):
    raise TypeError(f'`client_weighting` must be either an instance of '
                    f'`ClientWeighting` or a callable. '
                    f'Found type {type(client_weighting)}.')
  self._client_weighting = client_weighting

  self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
def __init__(
    self,
    model: model_lib.Model,
    optimizer: Union[optimizer_base.Optimizer,
                     Callable[[], tf.keras.optimizers.Optimizer]],
    client_weighting: client_weight_lib.ClientWeightType = client_weight_lib
    .ClientWeighting.NUM_EXAMPLES,
    use_experimental_simulation_loop: bool = False):
  """Creates the client computation for Federated Averaging.

  Note: All variable creation required for the client computation (e.g. model
  variable creation) must occur during construction, and not during
  `__call__`.

  Args:
    model: A `tff.learning.Model` instance.
    optimizer: An `optimizer_base.Optimizer` instance, or a no-arg callable
      that returns a `tf.keras.optimizers.Optimizer` instance.
    client_weighting: A value of `tff.learning.ClientWeighting` that
      specifies a built-in weighting method, or a callable that takes the
      output of `model.report_local_outputs` and returns a tensor that
      provides the weight in the federated average of model deltas.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.
  """
  py_typecheck.check_type(model, model_lib.Model)
  self._model = model
  self._optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      optimizer,
      model_utils.ModelWeights.from_model(self._model).trainable,
      disjoint_init_and_next=False)
  client_weight_lib.check_is_client_weighting_or_callable(client_weighting)
  self._client_weighting = client_weighting
  self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
def build_model_delta_update_with_tff_optimizer(
    model_fn: Callable[[], model_lib.Model],
    *,
    weighting: client_weight_lib.ClientWeighting,
    delta_l2_regularizer: float = 0.0,
    use_experimental_simulation_loop: bool = False):
  """Creates client update logic in FedAvg using a TFF optimizer.

  In contrast to using a `tf.keras.optimizers.Optimizer`, we avoid creating
  `tf.Variable`s associated with the optimizer state within the scope of the
  client work, as they are not necessary. This also means that the client's
  model weights are updated by computing `optimizer.next` and then assigning
  the result to the model weights (while a `tf.keras.optimizers.Optimizer`
  would modify the model weights in place using `optimizer.apply_gradients`).

  Args:
    model_fn: A no-arg callable returning a `tff.learning.Model`.
    weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float, L2 regularization strength of
      the model delta.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
  model = model_fn()
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)

  @tf.function
  def client_update(optimizer, initial_weights, data):
    model_weights = model_utils.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                          initial_weights)

    def reduce_fn(state, batch):
      """Trains a `tff.learning.Model` on a batch of data."""
      num_examples_sum, optimizer_state = state
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)

      gradients = tape.gradient(output.loss, model_weights.trainable)
      if delta_l2_regularizer > 0.0:
        proximal_term = tf.nest.map_structure(
            lambda x, y: delta_l2_regularizer * (y - x),
            model_weights.trainable, initial_weights.trainable)
        gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
      optimizer_state, updated_weights = optimizer.next(
          optimizer_state, tuple(tf.nest.flatten(model_weights.trainable)),
          tuple(tf.nest.flatten(gradients)))
      updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                 updated_weights)
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            model_weights.trainable, updated_weights)

      if output.num_examples is None:
        num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
      else:
        num_examples_sum += tf.cast(output.num_examples, tf.int64)
      return num_examples_sum, optimizer_state

    def initial_state_for_reduce_fn():
      # TODO(b/161529310): We flatten and convert the trainable specs to a
      # tuple, as the "for batch in data:" pattern would try to stack the
      # tensors in a list.
      trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype),
          tuple(tf.nest.flatten(model_weights.trainable)))
      return (tf.zeros(shape=[], dtype=tf.int64),
              optimizer.initialize(trainable_tensor_specs))

    num_examples, _ = dataset_reduce_fn(
        reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
    client_update = tf.nest.map_structure(tf.subtract,
                                          initial_weights.trainable,
                                          model_weights.trainable)
    model_output = model.report_local_unfinalized_metrics()

    # TODO(b/122071074): Consider moving this functionality into
    # tff.federated_mean?
    client_update, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(client_update))
    client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                          num_examples)
    return client_works.ClientResult(
        update=client_update, update_weight=client_weight), model_output

  return client_update
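# A hedged sketch (not library code) of the functional optimizer contract
# `client_update` above relies on: `initialize` builds pure state from tensor
# specs, and `next` returns `(new_state, updated_weights)` without mutating
# any `tf.Variable`s. `_SketchSgd` is a hypothetical name for illustration
# only.
class _SketchSgd:
  """Stateless SGD implementing the initialize/next contract used above."""

  def __init__(self, learning_rate: float):
    self._lr = learning_rate

  def initialize(self, specs):
    del specs  # Plain SGD keeps no per-weight optimizer state.
    return ()

  def next(self, state, weights, gradients):
    # Purely functional step: returns new weights instead of assigning them.
    updated = tuple(w - self._lr * g for w, g in zip(weights, gradients))
    return state, updated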
def build_model_delta_update_with_keras_optimizer(
    model_fn,
    weighting,
    delta_l2_regularizer=0.0,
    use_experimental_simulation_loop: bool = False):
  """Creates client update logic in FedAvg using a `tf.keras` optimizer.

  In contrast to using a `tff.learning.optimizers.Optimizer`, we have to
  maintain `tf.Variable`s associated with the optimizer state within the
  scope of the client work. Additionally, the client model weights are
  modified in place by using `optimizer.apply_gradients`.

  Args:
    model_fn: A no-arg callable returning a `tff.learning.Model`.
    weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float, L2 regularization strength of
      the model delta.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
  model = model_fn()
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)

  @tf.function
  def client_update(optimizer, initial_weights, data):
    model_weights = model_utils.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                          initial_weights)

    def reduce_fn(num_examples_sum, batch):
      """Trains a `tff.learning.Model` on a batch of data."""
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)

      gradients = tape.gradient(output.loss, model_weights.trainable)
      if delta_l2_regularizer > 0.0:
        proximal_term = tf.nest.map_structure(
            lambda x, y: delta_l2_regularizer * (y - x),
            model_weights.trainable, initial_weights.trainable)
        gradients = tf.nest.map_structure(tf.add, gradients, proximal_term)
      grads_and_vars = zip(gradients, model_weights.trainable)
      optimizer.apply_gradients(grads_and_vars)

      # TODO(b/199782787): Add a unit test for a model that does not compute
      # `num_examples` in its forward pass.
      if output.num_examples is None:
        num_examples_sum += tf.shape(output.predictions, out_type=tf.int64)[0]
      else:
        num_examples_sum += tf.cast(output.num_examples, tf.int64)
      return num_examples_sum

    def initial_state_for_reduce_fn():
      return tf.zeros(shape=[], dtype=tf.int64)

    num_examples = dataset_reduce_fn(
        reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
    client_update = tf.nest.map_structure(tf.subtract,
                                          initial_weights.trainable,
                                          model_weights.trainable)
    model_output = model.report_local_unfinalized_metrics()

    # TODO(b/122071074): Consider moving this functionality into
    # tff.federated_mean?
    client_update, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(client_update))
    client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                          num_examples)
    return client_works.ClientResult(
        update=client_update, update_weight=client_weight), model_output

  return client_update
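# A hedged sketch (illustration only, not library code) isolating the
# in-place update pattern used in `reduce_fn` above: a Keras optimizer
# mutates the variables passed to `apply_gradients`, so no reassignment step
# is needed afterwards.
def _sketch_inplace_keras_step():
  w = tf.Variable([1.0, 2.0])
  opt = tf.keras.optimizers.SGD(learning_rate=0.5)
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(w))
  grads = tape.gradient(loss, [w])
  opt.apply_gradients(zip(grads, [w]))  # `w` is updated in place.
  return w  # One step of w - 0.5 * 2w yields [0.0, 0.0].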
def client_update_fn(global_optimizer_state, initial_weights, data):
  model = model_fn()
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)
  weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(
      model_utils.weights_type_from_model(model))

  @tf.function
  def client_update(global_optimizer_state, initial_weights, data):
    model_weights = model_utils.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                          initial_weights)

    def full_gradient_reduce_fn(state, batch):
      """Sums individual gradients, to be later divided by num_examples."""
      gradient_sum, num_examples_sum = state
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
      if output.num_examples is None:
        num_examples = tf.shape(output.predictions, out_type=tf.int64)[0]
      else:
        num_examples = tf.cast(output.num_examples, tf.int64)
      # TODO(b/161529310): We flatten and convert to a tuple, as tf.data
      # iterators would try to stack the tensors in a list into a single
      # tensor.
      gradients = tuple(
          tf.nest.flatten(tape.gradient(output.loss,
                                        model_weights.trainable)))
      gradient_sum = tf.nest.map_structure(
          lambda g_sum, g: g_sum + g * tf.cast(num_examples, g.dtype),
          gradient_sum, gradients)
      num_examples_sum += num_examples
      return gradient_sum, num_examples_sum

    def initial_state_for_full_gradient_reduce_fn():
      initial_gradient_sum = tf.nest.map_structure(
          lambda spec: tf.zeros(spec.shape, spec.dtype),
          tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
      initial_num_examples_sum = tf.constant(0, tf.int64)
      return initial_gradient_sum, initial_num_examples_sum

    full_gradient, num_examples = dataset_reduce_fn(
        full_gradient_reduce_fn, data,
        initial_state_for_full_gradient_reduce_fn)
    # Compute the average gradient.
    full_gradient = tf.nest.map_structure(
        lambda g: tf.math.divide_no_nan(g, tf.cast(num_examples, g.dtype)),
        full_gradient)

    # Resets the local model variables, including metrics states, as we are
    # not interested in metrics based on the full gradient evaluation, only
    # from the subsequent training.
    model.reset_metrics()

    def train_reduce_fn(state, batch):
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
      gradients = tape.gradient(output.loss, model_weights.trainable)
      # Mime Lite keeps optimizer state unchanged during local training.
      _, updated_weights = optimizer.next(global_optimizer_state,
                                          model_weights.trainable, gradients)
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            model_weights.trainable, updated_weights)
      return state

    # Performs local training, updating `tf.Variable`s in `model_weights`.
    dataset_reduce_fn(
        train_reduce_fn, data, initial_state_fn=lambda: tf.zeros(shape=[0]))

    client_weights_delta = tf.nest.map_structure(tf.subtract,
                                                 initial_weights.trainable,
                                                 model_weights.trainable)
    model_output = model.report_local_unfinalized_metrics()

    # TODO(b/122071074): Consider moving this functionality into aggregation.
    client_weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
    client_weight = _choose_client_weight(client_weighting,
                                          has_non_finite_delta, num_examples)
    return client_works.ClientResult(
        update=client_weights_delta,
        update_weight=client_weight), model_output, full_gradient

  return client_update(global_optimizer_state, initial_weights, data)
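# A hedged sketch (hypothetical helper, not library code) of the Mime Lite
# local step in `train_reduce_fn` above: `optimizer.next` is called with the
# *global* optimizer state on every step and only the returned weights are
# kept, so local training never advances the optimizer state.
def _sketch_mime_lite_step(optimizer, global_state, weights, gradients):
  _, updated_weights = optimizer.next(global_state, weights, gradients)
  return updated_weights  # `global_state` is reused, never replaced.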
def test_build_dataset_reduce_fn(self, simulation, reduce_fn):
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
  self.assertIs(dataset_reduce_fn, reduce_fn)
  ds = tf.data.Dataset.range(10, output_type=tf.int32)
  total_sum = dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
  self.assertEqual(total_sum, np.int32(45))
from unittest import mock

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning import personalization_eval as p13n_eval
from tensorflow_federated.python.learning.framework import dataset_reduce

# TODO(b/160896627): Switch to `dataset.reduce` once multi-GPU supports it.
dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation_flag=True)


@tf.function
def _evaluate_fn(model, dataset, batch_size=1):
  """Evaluates a `tff.learning.Model` on the given dataset."""
  # Reset the local variables so that the returned metrics are computed using
  # the given data. Similar to the `reset_states` method of
  # `tf.metrics.Metric`.
  for var in model.local_variables:
    if var.initial_value is not None:
      var.assign(var.initial_value)
    else:
      var.assign(tf.zeros_like(var))

  def eval_fn(whimsy_state, batch):
    """Evaluates the model on a batch."""
def _build_client_update(model: model_lib.Model,
                         use_experimental_simulation_loop: bool = False):
  """Creates client update logic for FedSGD.

  Args:
    model: A `tff.learning.Model` used to compute gradients.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
  dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
      use_experimental_simulation_loop)

  @tf.function
  def client_update(initial_weights, dataset):
    model_weights = model_utils.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                          initial_weights)

    def reduce_fn(state, batch):
      """Runs forward_pass on batch and sums the weighted gradients."""
      accumulated_gradients, num_examples_sum = state

      with tf.GradientTape() as tape:
        output = model.forward_pass(batch)
      gradients = tape.gradient(output.loss, model_weights.trainable)
      num_examples = tf.cast(output.num_examples, tf.float32)
      accumulated_gradients = tuple(
          accumulator + num_examples * gradient
          for accumulator, gradient in zip(accumulated_gradients, gradients))

      # We may be able to optimize the reduce function to avoid doubling the
      # number of required variables here (e.g. keeping two copies of all
      # gradients). If you're looking to optimize memory usage this might be
      # a place to look.
      return (accumulated_gradients, num_examples_sum + num_examples)

    def _zero_initial_state():
      """Create a tuple of (gradient accumulators, num examples)."""
      return tuple(
          tf.nest.map_structure(tf.zeros_like,
                                model_weights.trainable)), tf.constant(
                                    0, dtype=tf.float32)

    gradient_sums, num_examples_sum = dataset_reduce_fn(
        reduce_fn=reduce_fn,
        dataset=dataset,
        initial_state_fn=_zero_initial_state)

    # We now normalize to compute the average gradient over all examples.
    average_gradient = tf.nest.map_structure(
        lambda gradient: gradient / num_examples_sum, gradient_sums)

    model_output = model.report_local_unfinalized_metrics()
    average_gradient, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(average_gradient))
    if has_non_finite_delta > 0:
      client_weight = tf.constant(0.0)
    else:
      client_weight = num_examples_sum

    return client_works.ClientResult(
        update=average_gradient, update_weight=client_weight), model_output

  return client_update
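# A minimal numeric sketch (illustration only) of the averaging done above:
# summing per-batch mean gradients scaled by batch size, then dividing by the
# total example count, yields the per-example average gradient.
def _sketch_weighted_gradient_average():
  batch_grads = [tf.constant(1.0), tf.constant(3.0)]  # mean gradient per batch
  batch_sizes = [tf.constant(2.0), tf.constant(6.0)]  # examples per batch
  weighted_sum = sum(g * n for g, n in zip(batch_grads, batch_sizes))
  total_examples = sum(batch_sizes)
  return weighted_sum / total_examples  # (1*2 + 3*6) / 8 = 2.5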