Example #1
  def __init__(self,
               model: model_lib.Model,
               batch_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
               use_experimental_simulation_loop: bool = False):
    """Constructs the client computation for Federated SGD.

    Note: All variable creation required for the client computation (e.g. model
    variable construction) must occur during construction, and not during
    `__call__`.

    Args:
      model: A `learning.Model` for which gradients are computed.
      batch_weight_fn: A function that takes a batch (as passed to forward_pass)
        and returns a float32 weight. If not provided, the default uses the size
        of the batch (as measured by the batch dimension of the predictions
        returned by forward_pass).
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.
    """
    if batch_weight_fn is not None:
      py_typecheck.check_callable(batch_weight_fn)
    self._batch_weight_fn = batch_weight_fn

    self._model = model_utils.enhance(model)
    py_typecheck.check_type(self._model, model_utils.EnhancedModel)

    self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)
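
For illustration, a minimal sketch of a `batch_weight_fn` that matches the contract described above; the `'x'` feature key is an assumption about the batch structure, not something fixed by this example.

import tensorflow as tf


def example_batch_weight_fn(batch):
  # Weight the batch by its number of examples, read from the batch dimension
  # of the (assumed) 'x' feature.
  return tf.cast(tf.shape(batch['x'])[0], tf.float32)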
Example #2
    def __init__(self,
                 model: model_lib.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
                 use_experimental_simulation_loop: bool = False):
        """Creates the client computation for Federated Averaging.

    Note: All variable creation required for the client computation (e.g. model
    variable creation) must occur during construction, and not during
    `__call__`.

    Args:
      model: A `tff.learning.Model` instance.
      optimizer: A `tf.keras.Optimizer` instance.
      client_weight_fn: An optional callable that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.
    """
        py_typecheck.check_type(model, model_lib.Model)
        self._model = model_utils.enhance(model)
        self._optimizer = optimizer
        py_typecheck.check_type(self._model, model_utils.EnhancedModel)

        if client_weight_fn is not None:
            py_typecheck.check_callable(client_weight_fn)
            self._client_weight_fn = client_weight_fn
        else:
            self._client_weight_fn = None

        self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
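
Similarly, a hypothetical `client_weight_fn` sketching the contract above: it takes the output of `model.report_local_outputs` and returns a weight tensor. The `'num_examples'` key is an assumption about that output structure.

import tensorflow as tf


def example_client_weight_fn(local_outputs):
  # Weight this client's model delta by its (assumed) reported example count.
  return tf.cast(local_outputs['num_examples'], tf.float32)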
Example #3
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluting `model_weights` on `dataset`."""
        with tf.init_scope():
            model = model_fn()
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                              incoming_model_weights)

        def reduce_fn(num_examples, batch):
            model_output = model.forward_pass(batch, training=False)
            if model_output.num_examples is None:
                # Compute the number of examples from the batch dimension of
                # the predictions if the model did not report `num_examples`.
                return num_examples + tf.shape(model_output.predictions,
                                               out_type=tf.int64)[0]
            else:
                return num_examples + tf.cast(model_output.num_examples,
                                              tf.int64)

        dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        num_examples = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
        return collections.OrderedDict(
            local_outputs=model.report_local_outputs(),
            num_examples=num_examples)
Example #4
 def test_build_dataset_reduce_fn_float(self, simulation, reduce_fn):
     dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
     self.assertIs(dataset_reduce_fn, reduce_fn)
     ds = tf.data.Dataset.range(
         10, output_type=tf.float32).map(lambda x: 0.1 * x)
     total_sum = dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
     self.assertEqual(total_sum, np.float32(4.5))
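
For context, a rough sketch of what `dataset_reduce.build_dataset_reduce_fn` is expected to return, inferred from how these examples call it (the helper names below are assumptions): with `simulation=True` it yields a Python `for`-loop reduce, otherwise one backed by the native `tf.data.Dataset.reduce` op.

import tensorflow as tf


def _for_iter_dataset_reduce(reduce_fn, dataset,
                             initial_state_fn=lambda: tf.constant(0)):
  # Simulation variant: iterate the dataset with an explicit Python loop.
  state = initial_state_fn()
  for batch in iter(dataset):
    state = reduce_fn(state, batch)
  return state


def _native_dataset_reduce(reduce_fn, dataset,
                           initial_state_fn=lambda: tf.constant(0)):
  # Non-simulation variant: delegate to `tf.data.Dataset.reduce`.
  return dataset.reduce(initial_state=initial_state_fn(), reduce_func=reduce_fn)


def build_dataset_reduce_fn(simulation_flag=True):
  return _for_iter_dataset_reduce if simulation_flag else _native_dataset_reduce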
Example #5
  def __init__(
      self,
      model: model_lib.Model,
      client_weighting: client_weight_lib.ClientWeightType = (
          client_weight_lib.ClientWeighting.NUM_EXAMPLES),
      use_experimental_simulation_loop: bool = False):
    """Constructs the client computation for Federated SGD.

    Note: All variable creation required for the client computation (e.g. model
    variable construction) must occur during construction, and not during
    `__call__`.

    Args:
      model: A `learning.Model` for which gradients are computed.
      client_weighting: A value of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method, or a callable that takes the
        output of `model.report_local_outputs` and returns a tensor that
        provides the weight in the federated average of model deltas.
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.
    """
    client_weight_lib.check_is_client_weighting_or_callable(client_weighting)
    self._client_weighting = client_weighting
    self._model = model
    py_typecheck.check_type(self._model, model_lib.Model)

    self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)
Example #6
 def test_build_dataset_reduce_fn_tuple(self, simulation, reduce_fn):
     dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
     self.assertIs(dataset_reduce_fn, reduce_fn)
     ds = tf.data.Dataset.range(
         10, output_type=tf.float32).map(lambda x: 0.1 * x)
     total_cnt, total_sum = dataset_reduce_fn(
         reduce_fn=lambda x, y: (x[0] + 1, x[1] + y),
         dataset=ds,
         initial_state_fn=lambda: (tf.constant(0), tf.constant(0.1)))
     self.assertEqual(total_cnt, np.float32(10))
     self.assertEqual(total_sum, np.float32(4.6))
Example #7
 def test_dataset_reduce_op_presence(self, simulation):
     with tf.Graph().as_default() as graph:
         dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
             simulation)
         ds = tf.data.Dataset.range(10, output_type=tf.int32)
         dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
     if simulation:
         self.assertNotIn(DATASET_REDUCE_OP,
                          _get_op_names(graph.as_graph_def()))
     else:
         self.assertIn(DATASET_REDUCE_OP,
                       _get_op_names(graph.as_graph_def()))
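
The constants and helpers used by this test are not shown above; the following is a hypothetical sketch of them, assuming the op emitted by `tf.data.Dataset.reduce` is named `ReduceDataset` (so it should appear in the graph only for the non-simulation variant) and that `_get_op_names` walks a `GraphDef` together with its function library.

import itertools

DATASET_REDUCE_OP = 'ReduceDataset'


def _get_op_names(graph_def):
  # Collect op types from the main graph and from any functions in its library.
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  return [node.op for node in all_nodes]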
Example #8
    def _tf_client_eval(incoming_model_weights, dataset):
      """Evaluation TF work."""
      tf_computation_utils.assign(model.weights, incoming_model_weights)

      def reduce_fn(prev_loss, batch):
        model_output = model.forward_pass(batch, training=False)
        return prev_loss + tf.cast(model_output.loss, tf.float64)

      dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
          use_experimental_simulation_loop)
      dataset_reduce_fn(
          reduce_fn=reduce_fn,
          dataset=dataset,
          initial_state_fn=lambda: tf.constant(0, dtype=tf.float64))

      return collections.OrderedDict([('local_outputs',
                                       model.report_local_outputs())])
Example #9
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluting `model_weights` on `dataset`."""
        with tf.init_scope():
            model = model_fn()
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                              incoming_model_weights)

        def reduce_fn(prev_loss, batch):
            model_output = model.forward_pass(batch, training=False)
            return prev_loss + tf.cast(model_output.loss, tf.float64)

        dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=lambda: tf.constant(0, dtype=tf.float64))
        return collections.OrderedDict(
            local_outputs=model.report_local_outputs())
Example #10
    def __init__(self,
                 model: model_lib.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 client_weighting: Union[
                     ClientWeighting,
                     ClientWeightFnType] = ClientWeighting.NUM_EXAMPLES,
                 use_experimental_simulation_loop: bool = False):
        """Creates the client computation for Federated Averaging.

    Note: All variable creation required for the client computation (e.g. model
    variable creation) must occur during construction, and not during
    `__call__`.

    Args:
      model: A `tff.learning.Model` instance.
      optimizer: A `tf.keras.Optimizer` instance.
      client_weighting: A value of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method, or a callable that takes the
        output of `model.report_local_outputs` and returns a tensor that
        provides the weight in the federated average of model deltas.
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.
    """
        py_typecheck.check_type(model, model_lib.Model)
        self._model = model_utils.enhance(model)
        self._optimizer = optimizer
        py_typecheck.check_type(self._model, model_utils.EnhancedModel)

        if (not isinstance(client_weighting, ClientWeighting)
                and not callable(client_weighting)):
            raise TypeError(f'`client_weighting` must be either instance of '
                            f'`ClientWeighting` or callable. '
                            f'Found type {type(client_weighting)}.')
        self._client_weighting = client_weighting

        self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
Example #11
    def __init__(
            self,
            model: model_lib.Model,
            optimizer: Union[optimizer_base.Optimizer,
                             Callable[[], tf.keras.optimizers.Optimizer]],
            client_weighting: client_weight_lib.ClientWeightType = (
                client_weight_lib.ClientWeighting.NUM_EXAMPLES),
            use_experimental_simulation_loop: bool = False):
        """Creates the client computation for Federated Averaging.

    Note: All variable creation required for the client computation (e.g. model
    variable creation) must occur during construction, and not during
    `__call__`.

    Args:
      model: A `tff.learning.Model` instance.
      optimizer: An `optimizer_base.Optimizer` instance, or a no-arg callable
        that returns a `tf.keras.Optimizer` instance.
      client_weighting: A value of `tff.learning.ClientWeighting` that
        specifies a built-in weighting method, or a callable that takes the
        output of `model.report_local_outputs` and returns a tensor that
        provides the weight in the federated average of model deltas.
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.
    """
        py_typecheck.check_type(model, model_lib.Model)
        self._model = model
        self._optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            optimizer,
            model_utils.ModelWeights.from_model(self._model).trainable,
            disjoint_init_and_next=False)
        client_weight_lib.check_is_client_weighting_or_callable(
            client_weighting)
        self._client_weighting = client_weighting
        self._dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
Example #12
def build_model_delta_update_with_tff_optimizer(
        model_fn: Callable[[], model_lib.Model],
        *,
        weighting: client_weight_lib.ClientWeighting,
        delta_l2_regularizer: float = 0.0,
        use_experimental_simulation_loop: bool = False):
    """Creates client update logic in FedAvg using a TFF optimizer.

  In contrast to using a `tf.keras.optimizers.Optimizer`, we avoid creating
  `tf.Variable`s associated with the optimizer state within the scope of the
  client work, as they are not necessary. This also means that the client's
  model weights are updated by computing `optimizer.next` and then assigning
  the result to the model weights (while a `tf.keras.optimizers.Optimizer` will
  modify the model weights in place using `optimizer.apply_gradients`).

  Args:
    model_fn: A no-arg callable returning a `tff.learning.Model`.
    weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float, L2 regularization strength of the
      model delta.
    use_experimental_simulation_loop: Controls the reduce loop function for the
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
    model = model_fn()
    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)

    @tf.function
    def client_update(optimizer, initial_weights, data):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                proximal_term = tf.nest.map_structure(
                    lambda x, y: delta_l2_regularizer * (y - x),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state,
                tuple(tf.nest.flatten(model_weights.trainable)),
                tuple(tf.nest.flatten(gradients)))
            updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                       updated_weights)
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_weights.trainable, updated_weights)

            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum, optimizer_state

        def initial_state_for_reduce_fn():
            # TODO(b/161529310): We flatten and convert the trainable specs to a
            # tuple, as the `for batch in data:` pattern would try to stack the
            # tensors in a list.
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                tuple(tf.nest.flatten(model_weights.trainable)))
            return (tf.zeros(shape=[], dtype=tf.int64),
                    optimizer.initialize(trainable_tensor_specs))

        num_examples, _ = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)

        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output

    return client_update
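
To make the `initialize`/`next` contract used above concrete, here is a minimal, hypothetical stateless SGD optimizer written against that interface. It only illustrates the functional update pattern the docstring contrasts with `tf.keras` optimizers; it is not part of the library.

class SketchSGD:
  """Hypothetical optimizer exposing the functional `initialize`/`next` API."""

  def __init__(self, learning_rate=0.1):
    self._lr = learning_rate

  def initialize(self, tensor_specs):
    # Plain SGD keeps no per-variable state, so an empty tuple suffices.
    del tensor_specs
    return ()

  def next(self, state, weights, gradients):
    # Return the (unchanged) state and the updated weights instead of mutating
    # variables; the caller assigns the result back, as in `client_update`.
    updated = tuple(w - self._lr * g for w, g in zip(weights, gradients))
    return state, updated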
Example #13
def build_model_delta_update_with_keras_optimizer(
        model_fn,
        weighting,
        delta_l2_regularizer=0.0,
        use_experimental_simulation_loop: bool = False):
    """Creates client update logic in FedAvg using a `tf.keras` optimizer.

  In contrast to using a `tff.learning.optimizers.Optimizer`, we have to
  maintain `tf.Variable`s associated with the optimizer state within the scope
  of the client work. Additionally, the client model weights are modified in
  place by using `optimizer.apply_gradients`.

  Args:
    model_fn: A no-arg callable returning a `tff.learning.Model`.
    weighting: A `tff.learning.ClientWeighting` value.
    delta_l2_regularizer: A nonnegative float, L2 regularization strength of the
      model delta.
    use_experimental_simulation_loop: Controls the reduce loop function for the
      input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
    model = model_fn()
    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)

    @tf.function
    def client_update(optimizer, initial_weights, data):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(num_examples_sum, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                proximal_term = tf.nest.map_structure(
                    lambda x, y: delta_l2_regularizer * (y - x),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            grads_and_vars = zip(gradients, model_weights.trainable)
            optimizer.apply_gradients(grads_and_vars)

            # TODO(b/199782787): Add a unit test for a model that does not compute
            # `num_examples` in its forward pass.
            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum

        def initial_state_for_reduce_fn():
            return tf.zeros(shape=[], dtype=tf.int64)

        num_examples = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)
        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output

    return client_update
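
A hypothetical usage sketch of the builder above; `model_fn`, `initial_weights`, and `client_dataset` are placeholders assumed to be defined elsewhere.

client_update = build_model_delta_update_with_keras_optimizer(
    model_fn=model_fn,
    weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# Runs local training and returns a `ClientResult` (weight delta plus its
# weight) together with the model's unfinalized metrics.
client_result, unfinalized_metrics = client_update(
    optimizer, initial_weights, client_dataset)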
Example #14
    def client_update_fn(global_optimizer_state, initial_weights, data):
        model = model_fn()
        dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(
            model_utils.weights_type_from_model(model))

        @tf.function
        def client_update(global_optimizer_state, initial_weights, data):
            model_weights = model_utils.ModelWeights.from_model(model)
            tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                                  initial_weights)

            def full_gradient_reduce_fn(state, batch):
                """Sums individual gradients, to be later divided by num_examples."""
                gradient_sum, num_examples_sum = state
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                if output.num_examples is None:
                    num_examples = tf.shape(output.predictions,
                                            out_type=tf.int64)[0]
                else:
                    num_examples = tf.cast(output.num_examples, tf.int64)
                # TODO(b/161529310): We flatten and convert to a tuple, as tf.data
                # iterators would try to stack the tensors in a list into a single tensor.
                gradients = tuple(
                    tf.nest.flatten(
                        tape.gradient(output.loss, model_weights.trainable)))
                gradient_sum = tf.nest.map_structure(
                    lambda g_sum, g: g_sum + g * tf.cast(
                        num_examples, g.dtype), gradient_sum, gradients)
                num_examples_sum += num_examples
                return gradient_sum, num_examples_sum

            def initial_state_for_full_gradient_reduce_fn():
                initial_gradient_sum = tf.nest.map_structure(
                    lambda spec: tf.zeros(spec.shape, spec.dtype),
                    tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
                initial_num_examples_sum = tf.constant(0, tf.int64)
                return initial_gradient_sum, initial_num_examples_sum

            full_gradient, num_examples = dataset_reduce_fn(
                full_gradient_reduce_fn, data,
                initial_state_for_full_gradient_reduce_fn)
            # Compute the average gradient.
            full_gradient = tf.nest.map_structure(
                lambda g: tf.math.divide_no_nan(
                    g, tf.cast(num_examples, g.dtype)), full_gradient)

            # Reset the local model variables, including metric states, as we are
            # not interested in metrics based on the full-gradient evaluation, only
            # in those from the subsequent training.
            model.reset_metrics()

            def train_reduce_fn(state, batch):
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                gradients = tape.gradient(output.loss, model_weights.trainable)
                # Mime Lite keeps optimizer state unchanged during local training.
                _, updated_weights = optimizer.next(global_optimizer_state,
                                                    model_weights.trainable,
                                                    gradients)
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      model_weights.trainable, updated_weights)
                return state

            # Performs local training, updating `tf.Variable`s in `model_weights`.
            dataset_reduce_fn(train_reduce_fn,
                              data,
                              initial_state_fn=lambda: tf.zeros(shape=[0]))

            client_weights_delta = tf.nest.map_structure(
                tf.subtract, initial_weights.trainable,
                model_weights.trainable)
            model_output = model.report_local_unfinalized_metrics()

            # TODO(b/122071074): Consider moving this functionality into aggregation.
            client_weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
            client_weight = _choose_client_weight(client_weighting,
                                                  has_non_finite_delta,
                                                  num_examples)
            return client_works.ClientResult(
                update=client_weights_delta,
                update_weight=client_weight), model_output, full_gradient

        return client_update(global_optimizer_state, initial_weights, data)
Example #15
 def test_build_dataset_reduce_fn(self, simulation, reduce_fn):
     dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation)
     self.assertIs(dataset_reduce_fn, reduce_fn)
     ds = tf.data.Dataset.range(10, output_type=tf.int32)
     total_sum = dataset_reduce_fn(reduce_fn=lambda x, y: x + y, dataset=ds)
     self.assertEqual(total_sum, np.int32(45))
Example #16
from unittest import mock

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model_examples
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning import personalization_eval as p13n_eval
from tensorflow_federated.python.learning.framework import dataset_reduce

# TODO(b/160896627): Switch to `dataset.reduce` once multi-GPU supports it.
dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation_flag=True)


@tf.function
def _evaluate_fn(model, dataset, batch_size=1):
  """Evaluates a `tff.learning.Model` on the given dataset."""
  # Reset the local variables so that the returned metrics are computed using
  # the given data. Similar to the `reset_states` method of `tf.metrics.Metric`.
  for var in model.local_variables:
    if var.initial_value is not None:
      var.assign(var.initial_value)
    else:
      var.assign(tf.zeros_like(var))

  def eval_fn(whimsy_state, batch):
    """Evaluates the model on a batch."""
Example #17
def _build_client_update(model: model_lib.Model,
                         use_experimental_simulation_loop: bool = False):
    """Creates client update logic for FedSGD.

  Args:
    model: A `tff.learning.Model` used to compute gradients.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.

  Returns:
    A `tf.function`.
  """
    dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop)

    @tf.function
    def client_update(initial_weights, dataset):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Runs forward_pass on batch and sums the weighted gradients."""
            accumulated_gradients, num_examples_sum = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            gradients = tape.gradient(output.loss, model_weights.trainable)
            num_examples = tf.cast(output.num_examples, tf.float32)
            accumulated_gradients = tuple(
                accumulator + num_examples * gradient
                for accumulator, gradient in zip(accumulated_gradients,
                                                 gradients))

            # We may be able to optimize the reduce function to avoid doubling the
            # number of required variables here (e.g. keeping two copies of all
            # gradients). If you're looking to optimize memory usage, this might be a
            # place to look.
            return (accumulated_gradients, num_examples_sum + num_examples)

        def _zero_initial_state():
            """Create a tuple of (gradient accumulators, num examples)."""
            return tuple(
                tf.nest.map_structure(tf.zeros_like,
                                      model_weights.trainable)), tf.constant(
                                          0, dtype=tf.float32)

        gradient_sums, num_examples_sum = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=_zero_initial_state)

        # We now normalize to compute the average gradient over all examples.
        average_gradient = tf.nest.map_structure(
            lambda gradient: gradient / num_examples_sum, gradient_sums)

        model_output = model.report_local_unfinalized_metrics()
        average_gradient, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(average_gradient))
        if has_non_finite_delta > 0:
            client_weight = tf.constant(0.0)
        else:
            client_weight = num_examples_sum

        return client_works.ClientResult(
            update=average_gradient, update_weight=client_weight), model_output

    return client_update
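
Finally, a hypothetical usage sketch of `_build_client_update`; `model_fn` and `client_dataset` are assumed placeholders.

model = model_fn()
client_update = _build_client_update(model, use_experimental_simulation_loop=True)
initial_weights = model_utils.ModelWeights.from_model(model)
# `client_result.update` holds the example-weighted average gradient and
# `client_result.update_weight` the example count (or 0.0 if any gradient
# entry was non-finite).
client_result, unfinalized_metrics = client_update(initial_weights, client_dataset)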