Example #1
    def client_update(optimizer, initial_weights, data):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                proximal_term = tf.nest.map_structure(
                    lambda x, y: delta_l2_regularizer * (y - x),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state,
                tuple(tf.nest.flatten(model_weights.trainable)),
                tuple(tf.nest.flatten(gradients)))
            updated_weights = tf.nest.pack_sequence_as(model_weights.trainable,
                                                       updated_weights)
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_weights.trainable, updated_weights)

            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum, optimizer_state

        def initial_state_for_reduce_fn():
            # TODO(b/161529310): We flatten and convert the trainable specs to a tuple,
            # as the "for batch in data:" pattern would try to stack the tensors in a list.
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                tuple(tf.nest.flatten(model_weights.trainable)))
            return (tf.zeros(shape=[], dtype=tf.int64),
                    optimizer.initialize(trainable_tensor_specs))

        num_examples, _ = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)

        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output
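Example #1 and several later examples (e.g. #16 and #20) call a `_choose_client_weight(weighting, has_non_finite_delta, num_examples)` helper that is not shown. A minimal sketch of what such a helper could look like, inferred from the weighting branches in Examples #13 and #25; the local `ClientWeighting` enum here is an assumption standing in for the library's own constants:

    import enum

    import tensorflow as tf


    class ClientWeighting(enum.Enum):
        NUM_EXAMPLES = 1
        UNIFORM = 2


    def _choose_client_weight(weighting, has_non_finite_delta, num_examples):
        """Sketch: picks the weight a client's delta gets in the federated mean."""
        if has_non_finite_delta > 0:
            # The delta was zeroed out, so this client should not move the average.
            return tf.constant(0.0, tf.float32)
        if weighting == ClientWeighting.NUM_EXAMPLES:
            return tf.cast(num_examples, tf.float32)
        elif weighting == ClientWeighting.UNIFORM:
            return tf.constant(1.0, tf.float32)
        else:
            raise ValueError(f'Unexpected weighting: {weighting}')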
Example #2
def server_update(model, server_optimizer, server_state, weights_delta):
    """Updates `server_state` based on `weights_delta`, increase the round number.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.

  Returns:
    An updated `ServerState`.
  """
    model_weights = _get_weights(model)
    tff.utils.assign(model_weights, server_state.model)
    # Server optimizer variables must be initialized prior to invoking this
    # function.
    tff.utils.assign(server_optimizer.variables(),
                     server_state.optimizer_state)

    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    if has_non_finite_weight > 0:
        return server_state

    # Apply the update to the model. We must multiply weights_delta by -1.0 to
    # view it as a gradient that should be applied to the server_optimizer.
    grads_and_vars = [(-1.0 * x, v)
                      for x, v in zip(weights_delta, model_weights.trainable)]

    server_optimizer.apply_gradients(grads_and_vars)

    # Create a new state based on the updated model.
    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=server_optimizer.variables(),
                                  round_num=server_state.round_num + 1.0)
Example #3
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
                  weights_delta, grads_norm):
    """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_optimizer_vars: A list of previous variables of server_optimizer.
    server_state: A `ServerState` namedtuple, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    grads_norm: Summation of the norm of gradients from clients.

  Returns:
    An updated `ServerState`.
  """
    model_weights = tff.learning.framework.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t),
                          (model_weights, server_optimizer_vars),
                          (server_state.model, server_state.optimizer_state))

    # Zero out the weight if there are any non-finite values.
    weights_delta, _ = (tensor_utils.zero_all_if_any_non_finite(weights_delta))

    grads_and_vars = tf.nest.map_structure(
        lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
        tf.nest.flatten(model_weights.trainable))

    server_optimizer.update_grads_norm(
        tf.nest.flatten(model_weights.trainable), grads_norm)
    server_optimizer.apply_gradients(grads_and_vars, name='server_update')

    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=server_optimizer_vars)
Example #4
    def server_update(global_model, mean_model_delta, optimizer_state):
        """Updates the global model with the mean model update from clients."""
        with tf.init_scope():
            # Create a structure of variables that the server optimizer can update.
            model_variables = tf.nest.map_structure(
                lambda t: tf.Variable(initial_value=tf.zeros(t.shape, t.dtype)),
                global_model)
            optimizer = keras_optimizer.build_or_verify_tff_optimizer(
                server_optimizer_fn,
                model_variables.trainable,
                disjoint_init_and_next=True)

        # Set the variables to the current global model, the optimizer will
        # update these variables.
        tf.nest.map_structure(lambda a, b: a.assign(b), model_variables,
                              global_model)
        # We might have a NaN value e.g. if all of the clients processed had no
        # data, so the denominator in the federated_mean is zero. If we see any
        # NaNs, zero out the whole update.
        # TODO(b/124538167): We should increment a server counter to
        # track the fact a non-finite weights_delta was encountered.
        finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
            mean_model_delta)
        # Update the global model variables with the delta as a pseudo-gradient.
        negative_weights_delta = tf.nest.map_structure(lambda w: -1.0 * w,
                                                       finite_weights_delta)
        optimizer_state, updated_weights = optimizer.next(
            optimizer_state, model_variables.trainable, negative_weights_delta)
        # Keras optimizers mutate model variables within the `next` step above, so
        # we skip calling the assignment for those optimizers.
        if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
            tf.nest.map_structure(lambda a, b: a.assign(b),
                                  model_variables.trainable, updated_weights)
        return model_variables, optimizer_state
Example #5
def client_computation(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        train_discriminator_fn) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.

  Returns:
    A `ClientOutput` object.
  """
    tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
                          from_server.discriminator_weights)

    num_examples = tf.constant(0)
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        num_examples += train_discriminator_fn(generator, discriminator,
                                               gen_inputs, real_data)

    weights_delta = tf.nest.map_structure(tf.subtract, discriminator.weights,
                                          from_server.discriminator_weights)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                            lambda: update_weight, lambda: tf.constant(0.0))
    return ClientOutput(
        discriminator_weights_delta=weights_delta,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})
Example #6
 def server_update(global_model, mean_model_delta, optimizer_state):
     """Updates the global model with the mean model update from clients."""
     with tf.init_scope():
         model = model_fn()
         optimizer = server_optimizer_fn()
         # We must force variable creation for momentum and adaptive optimizers.
         _eagerly_create_optimizer_variables(model=model,
                                             optimizer=optimizer)
     model_variables = model_utils.ModelWeights.from_model(model)
     optimizer_variables = optimizer.variables()
     # Set the variables to the current global model, the optimizer will
     # update these variables.
     tf.nest.map_structure(lambda a, b: a.assign(b),
                           (model_variables, optimizer_variables),
                           (global_model, optimizer_state))
     # We might have a NaN value e.g. if all of the clients processed had no
     # data, so the denominator in the federated_mean is zero. If we see any
     # NaNs, zero out the whole update.
     # TODO(b/124538167): We should increment a server counter to
     # track the fact a non-finite weights_delta was encountered.
     finite_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
         mean_model_delta)
     # Update the global model variables with the delta as a pseudo-gradient.
     _apply_delta(optimizer=optimizer,
                  model=model,
                  delta=finite_weights_delta)
     return model_variables, optimizer_variables
Example #7
    def __call__(self, dataset, initial_weights):
        # TODO(b/123898430): The control dependencies below have been inserted as a
        # temporary workaround. These control dependencies need to be removed, and
        # defuns and datasets supported together fully.
        model = self._model

        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        # TODO(b/120801384): We should initialize model.local_variables here.
        # Or, we may just need a convention that TFF initializes all variables
        # before invoking the TF function.

        # We must assign to a variable here in order to use control_dependencies.
        dummy_weights = nest.map_structure(tf.assign, model.weights,
                                           initial_weights)

        with tf.control_dependencies(list(dummy_weights.trainable.values())):

            def reduce_fn(dummy_state, batch):
                """Runs `tff.learning.Model.train_on_batch` on local client batch."""
                output = model.train_on_batch(batch)
                tf.assign_add(self._num_examples,
                              tf.shape(output.predictions)[0])
                return dummy_state

            # TODO(b/124477598): Remove dummy_output when b/121400757 fixed.
            dummy_output = dataset.reduce(initial_state=tf.constant(0.0),
                                          reduce_func=reduce_fn)

        with tf.control_dependencies([dummy_output]):
            weights_delta = nest.map_structure(tf.subtract,
                                               model.weights.trainable,
                                               initial_weights.trainable)
            aggregated_outputs = model.report_local_outputs()
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)  # pylint:disable=not-callable

            # TODO(b/122071074): Consider moving this functionality into
            # tff.federated_average?
            weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(weights_delta))
            weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                           lambda: weights_delta_weight,
                                           lambda: tf.constant(0))

            return optimizer_utils.ClientOutput(
                weights_delta, weights_delta_weight, aggregated_outputs,
                tensor_utils.to_odict({
                    'num_examples':
                    self._num_examples.value(),
                    'has_non_finite_delta':
                    has_non_finite_delta,
                    'workaround for b/121400757':
                    dummy_output,
                }))
Example #8
    def __call__(self, dataset, initial_weights):
        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        model = self._model
        dummy_weights = nest.map_structure(tf.assign, model.weights,
                                           initial_weights)

        def reduce_fn(accumulated_grads, batch):
            """Runs forward_pass on batch."""
            with tf.contrib.eager.GradientTape() as tape:
                output = model.forward_pass(batch)

            with tf.control_dependencies(list(output)):
                flat_vars = nest.flatten(model.weights.trainable)
                grads = nest.pack_sequence_as(
                    accumulated_grads, tape.gradient(output.loss, flat_vars))

                if self._batch_weight_fn is not None:
                    batch_weight = self._batch_weight_fn(batch)
                else:
                    batch_weight = tf.cast(
                        tf.shape(output.predictions)[0], tf.float32)

            tf.assign_add(self._batch_weight_sum, batch_weight)
            return nest.map_structure(
                lambda accumulator, grad: accumulator + batch_weight * grad,
                accumulated_grads, grads)

        with tf.control_dependencies(list(dummy_weights.trainable.values())):
            self._grad_sum_vars = dataset.reduce(
                initial_state=self._grad_sum_vars, reduce_func=reduce_fn)

        with tf.control_dependencies(
            [tf.identity(v) for v in self._grad_sum_vars.values()]):
            # For SGD, the delta is just the negative of the average gradient:
            weights_delta = nest.map_structure(
                lambda gradient: -1.0 * gradient / self._batch_weight_sum,
                self._grad_sum_vars)
            weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(weights_delta))
            weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                           lambda: self._batch_weight_sum,
                                           lambda: tf.constant(0.0))

            return optimizer_utils.ClientOutput(
                weights_delta, weights_delta_weight,
                model.report_local_outputs(),
                tensor_utils.to_odict({
                    'client_weight':
                    weights_delta_weight,
                    'has_non_finite_delta':
                    has_non_finite_delta,
                }))
Example #9
    def client_update(model,
                      dataset,
                      initial_weights,
                      client_optimizer,
                      client_id,
                      client_weight_fn=None):
        """Updates client model.

    Args:
      model: A `tff.learning.Model`.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.ModelWeights` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer` object.
      client_id: An identifier for the client, passed through to the returned
        `ClientOutput`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.

    Returns:
      A `ClientOutput`.
    """

        model_weights = _get_weights(model)
        tff.utils.assign(model_weights, initial_weights)

        num_examples = tf.constant(0, dtype=tf.int32)
        for batch in dataset:
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            grads = tape.gradient(output.loss, model_weights.trainable)
            #grads = tf.nest.map_structure(lambda g: clip_ops.clip_by_norm(g,5.0), grads)
            grads, _ = tf.clip_by_global_norm(grads, 1.0)
            grads_and_vars = zip(grads, model_weights.trainable)
            client_optimizer.apply_gradients(grads_and_vars)
            num_examples += tf.shape(output.predictions)[0]

        aggregated_outputs = model.report_local_outputs()
        weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                              model_weights.trainable,
                                              initial_weights.trainable)
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        if has_non_finite_weight > 0:
            client_weight = tf.constant([[0]], dtype=tf.float32)
        else:
            client_weight = tf.cast([[num_examples]], dtype=tf.float32)
        # else:
        # client_weight = client_weight_fn(aggregated_outputs)

        #weights_delta_encoded = tf.nest.map_structure(mean_encoder_fn, weights_delta)

        return ClientOutput(
            weights_delta, client_weight, aggregated_outputs,
            collections.OrderedDict([('num_examples', num_examples)]),
            client_id)
Example #10
    def __call__(self, dataset, initial_weights):
        # N.B. When not in eager mode, this code must be wrapped as a defun
        # as it uses program-order semantics to avoid adding many explicit
        # control dependencies.
        model = self._model
        py_typecheck.check_type(dataset, tf.data.Dataset)

        nest.map_structure(tf.assign, model.weights, initial_weights)

        @tf.contrib.eager.function(autograph=False)
        def reduce_fn(dummy_state, batch):
            """Runs forward_pass on batch."""
            with tf.contrib.eager.GradientTape() as tape:
                output = model.forward_pass(batch)

            flat_vars = nest.flatten(model.weights.trainable)
            grads = nest.pack_sequence_as(
                self._grad_sum_vars, tape.gradient(output.loss, flat_vars))

            if self._batch_weight_fn is not None:
                batch_weight = self._batch_weight_fn(batch)
            else:
                batch_weight = tf.cast(
                    tf.shape(output.predictions)[0], tf.float32)

            tf.assign_add(self._batch_weight_sum, batch_weight)
            nest.map_structure(
                lambda v, g:  # pylint:disable=g-long-lambda
                tf.assign_add(v, batch_weight * g),
                self._grad_sum_vars,
                grads)

            return dummy_state

        # TODO(b/121400757): Remove dummy_output when bug fixed.
        dummy_output = dataset.reduce(initial_state=tf.constant(0.0),
                                      reduce_func=reduce_fn)

        # For SGD, the delta is just the negative of the average gradient:
        # TODO(b/109733734): Might be better to send the weighted grad sums
        # and the denominator separately?
        weights_delta = nest.map_structure(
            lambda g: -1.0 * g / self._batch_weight_sum, self._grad_sum_vars)
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        weights_delta_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                                       lambda: self._batch_weight_sum,
                                       lambda: tf.constant(0.0))

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, model.report_local_outputs(),
            tensor_utils.to_odict({
                'client_weight': weights_delta_weight,
                'has_non_finite_delta': has_non_finite_delta,
                'workaround for b/121400757': dummy_output,
            }))
Example #11
 def expect_zeros(structure, expected):
   with tf.Graph().as_default():
     result, error = tensor_utils.zero_all_if_any_non_finite(structure)
     with self.session() as sess:
       result, error = sess.run((result, error))
     try:
       tf.nest.map_structure(np.testing.assert_allclose, result, expected)
     except AssertionError:
       self.fail('Expected to get zeros, but instead got {}'.format(result))
     self.assertEqual(error, 1)
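Nearly every example here guards the client or server delta with `tensor_utils.zero_all_if_any_non_finite`. Its contract can be read off the test in Example #11: the structure comes back unchanged with a flag of 0 when every entry is finite, and comes back zeroed with a flag of 1 otherwise. A minimal sketch of that contract (not the library's actual implementation; it assumes a non-empty structure of float tensors):

    import tensorflow as tf


    def zero_all_if_any_non_finite(structure):
        """Sketch: returns (structure, 0) if all entries are finite, else (zeros, 1)."""
        flat = tf.nest.flatten(structure)
        # Scalar bool that is True only if every element of every tensor is finite.
        all_finite = tf.reduce_all(
            tf.stack([tf.reduce_all(tf.math.is_finite(t)) for t in flat]))
        result = tf.cond(
            all_finite,
            lambda: structure,
            lambda: tf.nest.map_structure(tf.zeros_like, structure))
        error = tf.cast(tf.logical_not(all_finite), tf.int32)
        return result, error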
Example #12
    def server_update(server_state, weights_delta, aggregator_state,
                      broadcaster_state):
        """Updates the `server_state` based on `weights_delta`.

    Args:
      server_state: A `tff.learning.framework.ServerState`, the state to be
        updated.
      weights_delta: The model delta in global trainable variables from clients.
      aggregator_state: The state of the aggregator after performing
        aggregation.
      broadcaster_state: The state of the broadcaster after broadcasting.

    Returns:
      The updated `tff.learning.framework.ServerState`.
    """
        with tf.init_scope():
            model = model_fn()
        global_model_weights = reconstruction_utils.get_global_variables(model)
        optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            server_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=True)
        optimizer_state = server_state.optimizer_state

        # Initialize the model with the current state.
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              server_state.model)

        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        # We ignore the update if the weights_delta is non finite.
        if tf.equal(has_non_finite_weight, 0):
            negative_weights_delta = tf.nest.map_structure(
                lambda w: -1.0 * w, weights_delta)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state, global_model_weights.trainable,
                negative_weights_delta)
            if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      global_model_weights.trainable,
                                      updated_weights)

        # Create a new state based on the updated model.
        return structure.update_struct(
            server_state,
            model=global_model_weights,
            optimizer_state=optimizer_state,
            model_broadcast_state=broadcaster_state,
            delta_aggregate_state=aggregator_state,
        )
Example #13
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        def reduce_fn(num_examples_sum, batch):
            """Train `tff.learning.Model` on local client batch."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model.weights.trainable)
            optimizer.apply_gradients(zip(gradients, model.weights.trainable))

            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions,
                                                   out_type=tf.int64)[0]
            else:
                return num_examples_sum + tf.cast(output.num_examples,
                                                  tf.int64)

        num_examples_sum = self._dataset_reduce_fn(
            reduce_fn,
            dataset,
            initial_state_fn=lambda: tf.zeros(shape=[], dtype=tf.int64))

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        model_output = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weighting is ClientWeighting.NUM_EXAMPLES:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        elif self._client_weighting is ClientWeighting.UNIFORM:
            weights_delta_weight = tf.constant(1.0)
        else:
            weights_delta_weight = self._client_weighting(model_output)
        # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
        optimizer_output = collections.OrderedDict(
            num_examples=num_examples_sum)
        return optimizer_utils.ClientOutput(weights_delta,
                                            weights_delta_weight, model_output,
                                            optimizer_output)
Example #14
def server_update_model(
    server_state: ServerState,
    weights_delta: Collection[tf.Tensor],
    model_fn: _ModelConstructor,
    optimizer_fn: _OptimizerConstructor,
) -> ServerState:
    """Updates `server_state` based on `weights_delta`.

  Args:
    server_state: A `tff.learning.framework.ServerState` namedtuple, the state
      to be updated.
    weights_delta: An update to the trainable variables of the model.
    model_fn: A no-arg function that returns a `tff.learning.Model`. Passing in
      a function ensures any variables are created when server_update_model is
      called, so they can be captured in a specific graph or other context.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As with
      model_fn, we pass in a function to control when variables are created.

  Returns:
    An updated `tff.learning.framework.ServerState`.
  """
    py_typecheck.check_type(server_state, ServerState)
    py_typecheck.check_type(weights_delta, collections.abc.Collection)
    model = model_utils.enhance(model_fn())
    optimizer = optimizer_fn()
    apply_delta_fn, optimizer_vars = _build_server_optimizer(model, optimizer)

    # We might have a NaN value e.g. if all of the clients processed
    # had no data, so the denominator in the federated_mean is zero.
    # If we see any NaNs, zero out the whole update.
    no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
        weights_delta)
    # TODO(b/124538167): We should increment a server counter to
    # track the fact a non-finite weights_delta was encountered.

    @tf.function
    def update_model_inner():
        """Applies the update."""
        tf.nest.map_structure(
            lambda a, b: a.assign(b), (model.weights, optimizer_vars),
            (server_state.model, server_state.optimizer_state))
        apply_delta_fn(no_nan_weights_delta)
        return model.weights, optimizer_vars

    model_weights, optimizer_vars = update_model_inner()
    # TODO(b/123092620): We must do this outside of the above tf.function, because
    # there could be an AnonymousTuple hiding in server_state,
    # and tf.function's can't return AnonymousTuples.
    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=optimizer_vars)
Example #15
    def client_update(initial_weights, dataset):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Runs forward_pass on batch and sums the weighted gradients."""
            accumulated_gradients, num_examples_sum = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            gradients = tape.gradient(output.loss, model_weights.trainable)
            num_examples = tf.cast(output.num_examples, tf.float32)
            accumulated_gradients = tuple(
                accumulator + num_examples * gradient
                for accumulator, gradient in zip(accumulated_gradients,
                                                 gradients))

            # We may be able to optimize the reduce function to avoid doubling the
            # number of required variables here (e.g. keeping two copies of all
            # gradients). If you're looking to optimize memory usage this might be a
            # place to look.
            return (accumulated_gradients, num_examples_sum + num_examples)

        def _zero_initial_state():
            """Create a tuple of (gradient accumulators, num examples)."""
            return tuple(
                tf.nest.map_structure(tf.zeros_like,
                                      model_weights.trainable)), tf.constant(
                                          0, dtype=tf.float32)

        gradient_sums, num_examples_sum = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=_zero_initial_state)

        # We now normalize to compute the average gradient over all examples.
        average_gradient = tf.nest.map_structure(
            lambda gradient: gradient / num_examples_sum, gradient_sums)

        model_output = model.report_local_unfinalized_metrics()
        average_gradient, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(average_gradient))
        if has_non_finite_delta > 0:
            client_weight = tf.constant(0.0)
        else:
            client_weight = num_examples_sum

        return client_works.ClientResult(
            update=average_gradient, update_weight=client_weight), model_output
Example #16
    def client_update(optimizer, initial_weights, data):
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(num_examples_sum, batch):
            """Trains a `tff.learning.Model` on a batch of data."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            if delta_l2_regularizer > 0.0:
                proximal_term = tf.nest.map_structure(
                    lambda x, y: delta_l2_regularizer * (y - x),
                    model_weights.trainable, initial_weights.trainable)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            grads_and_vars = zip(gradients, model_weights.trainable)
            optimizer.apply_gradients(grads_and_vars)

            # TODO(b/199782787): Add a unit test for a model that does not compute
            # `num_examples` in its forward pass.
            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum

        def initial_state_for_reduce_fn():
            return tf.zeros(shape=[], dtype=tf.int64)

        num_examples = dataset_reduce_fn(
            reduce_fn, data, initial_state_fn=initial_state_for_reduce_fn)
        client_update = tf.nest.map_structure(tf.subtract,
                                              initial_weights.trainable,
                                              model_weights.trainable)
        model_output = model.report_local_unfinalized_metrics()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        client_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)
        return client_works.ClientResult(
            update=client_update, update_weight=client_weight), model_output
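Examples #1, #15, #16, and #24 reduce over the client dataset through a `dataset_reduce_fn(reduce_fn, dataset, initial_state_fn)` helper instead of calling `tf.data.Dataset.reduce` directly. A sketch of the two interchangeable forms such a helper typically takes (the function names below are illustrative, not the library's):

    import tensorflow as tf


    def _dataset_reduce_fn(reduce_fn, dataset, initial_state_fn):
        """Reduces the dataset with the built-in `tf.data.Dataset.reduce`."""
        return dataset.reduce(initial_state=initial_state_fn(), reduce_func=reduce_fn)


    def _for_iter_dataset_fn(reduce_fn, dataset, initial_state_fn):
        """Same reduction as a Python loop, for runtimes where `Dataset.reduce`
        inside a `tf.function` is not supported (e.g. some multi-GPU setups)."""
        state = initial_state_fn()
        for batch in iter(dataset):
            state = reduce_fn(state, batch)
        return state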
Example #17
    def __call__(self, dataset, initial_weights):
        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        model = self._model
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        @tf.function
        def reduce_fn(num_examples_sum, batch):
            """Runs `tff.learning.Model.train_on_batch` on local client batch."""
            output = model.train_on_batch(batch)
            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions)[0]
            else:
                return num_examples_sum + output.num_examples

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        aggregated_outputs = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        if self._client_weight_fn is None:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        else:
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            tensor_utils.to_odict({
                'num_examples': num_examples_sum,
                'has_non_finite_delta': has_non_finite_delta,
            }))
Example #18
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)

        @tf.function
        def reduce_fn(num_examples_sum, batch):
            """Runs `tff.learning.Model.train_on_batch` on local client batch."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model.weights.trainable)
            optimizer.apply_gradients(zip(gradients, model.weights.trainable))

            if output.num_examples is None:
                return num_examples_sum + tf.shape(output.predictions)[0]
            else:
                return num_examples_sum + output.num_examples

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model.weights.trainable,
                                              initial_weights.trainable)
        aggregated_outputs = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weight_fn is None:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        else:
            weights_delta_weight = self._client_weight_fn(aggregated_outputs)

        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            collections.OrderedDict(
                num_examples=num_examples_sum,
                has_non_finite_delta=has_non_finite_delta,
            ))
Example #19
    def client_update(model, dataset, initial_weights, client_optimizer):
        """Updates client model.

    Args:
      model: A `tff.learning.Model`.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.ModelWeights` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer` object.

    Returns:
      A `ClientOutput`.
    """

        model_weights = _get_weights(model)
        tff.utils.assign(model_weights, initial_weights)

        num_examples = tf.constant(0, dtype=tf.int32)
        for batch in dataset:
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            grads = tape.gradient(output.loss, model_weights.trainable)
            grads_and_vars = zip(grads, model_weights.trainable)
            client_optimizer.apply_gradients(grads_and_vars)
            num_examples += tf.shape(output.predictions)[0]

        aggregated_outputs = model.report_local_outputs()
        weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                              model_weights.trainable,
                                              initial_weights.trainable)
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        if has_non_finite_weight > 0:
            client_weight = tf.constant(0, dtype=tf.float32)
        else:
            client_weight = tf.constant(1, dtype=tf.float32)
        #elif client_weight_fn is None:
        #client_weight = tf.cast(num_examples, dtype=tf.float32)
        #else:
        #client_weight = client_weight_fn(aggregated_outputs)

        return ClientOutput(
            weights_delta, client_weight, aggregated_outputs,
            collections.OrderedDict([('num_examples', num_examples)]))
Example #20
    def client_update_fn(optimizer, initial_weights, dataset):
        initial_trainable_weights = initial_weights[0]
        trainable_tensor_specs = tf.nest.map_structure(
            tf.TensorSpec.from_tensor,
            tuple(tf.nest.flatten(initial_trainable_weights)))
        optimizer_state = optimizer.initialize(trainable_tensor_specs)

        # Autograph requires we define these variables once outside the loop.
        model_weights = initial_weights
        trainable_weights, non_trainable_weights = model_weights
        num_examples = tf.constant(0, tf.int64)
        for batch in iter(dataset):
            trainable_weights, non_trainable_weights = model_weights
            with tf.GradientTape() as tape:
                # Must explicitly watch non-variable tensors.
                tape.watch(trainable_weights)
                output = model.forward_pass(model_weights,
                                            batch,
                                            training=True)
            gradients = tape.gradient(output.loss, trainable_weights)
            if tf.greater(delta_l2_regularizer, 0.0):
                proximal_term = tf.nest.map_structure(
                    lambda x, y: delta_l2_regularizer * (y - x),
                    trainable_weights, initial_trainable_weights)
                gradients = tf.nest.map_structure(tf.add, gradients,
                                                  proximal_term)
            optimizer_state, trainable_weights = optimizer.next(
                optimizer_state, trainable_weights, gradients)
            num_examples += tf.cast(output.num_examples, tf.int64)
            model_weights = (trainable_weights, non_trainable_weights)
        # After all local batches, compute the delta between the trained model
        # and the initial incoming model weights.
        client_model_update = tf.nest.map_structure(tf.subtract,
                                                    initial_trainable_weights,
                                                    trainable_weights)
        # TODO(b/229612282): Implement metrics.
        model_output = ()
        client_model_update, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(client_model_update))
        client_weight = _choose_client_weight(weighting, has_non_finite_delta,
                                              num_examples)
        return client_works.ClientResult(
            update=client_model_update,
            update_weight=client_weight), model_output
Example #21
    def update_model_inner(weights_delta):
      """Applies the update to the global model."""
      model_variables = model_utils.ModelWeights.from_model(model)
      optimizer_variables = optimizer.variables()
      # We might have a NaN value e.g. if all of the clients processed
      # had no data, so the denominator in the federated_mean is zero.
      # If we see any NaNs, zero out the whole update.
      no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
          weights_delta)

      # TODO(b/124538167): We should increment a server counter to
      # track the fact a non-finite weights_delta was encountered.

      # Set the variables to the current global model (before update).
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            (model_variables, optimizer_variables),
                            (global_model, optimizer_state))
      # Update the variables with the delta, and return the new global model.
      _apply_delta(optimizer=optimizer, model=model, delta=no_nan_weights_delta)
      return model_variables, optimizer_variables
Example #22
    def __call__(self, dataset, initial_weights):
        del initial_weights
        model = self._model

        @tf.function
        def reduce_fn_num_examples(num_examples_sum, batch):
            """Count number of examples."""
            num_examples_in_batch = tf.shape(batch['x'])[0]
            return num_examples_sum + num_examples_in_batch

        @tf.function
        def reduce_fn_dataset_mean(sum_vector, batch):
            """Sum all the examples in the local dataset."""
            sum_batch = tf.reshape(tf.reduce_sum(batch['x'], [0]), (-1, 1))
            return sum_vector + sum_batch

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn_num_examples)

        example_vector_sum = dataset.reduce(initial_state=tf.zeros((DIM, 1)),
                                            reduce_func=reduce_fn_dataset_mean)

        # Create an ordered dictionary with the same structure as model.trainable,
        # containing the mean of all the examples in the local dataset.
        # Note: this works for a linear model only (as in the example above).
        key = list(model.weights.trainable.keys())[0]
        weights_delta = collections.OrderedDict(
            {key: example_vector_sum / tf.cast(num_examples_sum, tf.float32)})

        aggregated_outputs = model.report_local_outputs()
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        weights_delta_weight = tf.cast(num_examples_sum, tf.float32)

        return tff.learning.framework.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            collections.OrderedDict([
                ('num_examples', num_examples_sum),
                ('has_non_finite_delta', has_non_finite_delta),
            ]))
Example #23
def server_update_model(current_server_state, weights_delta, model_fn,
                        optimizer_fn):
  """Updates `server_state` based on `weights_delta`.

  Args:
    current_server_state: A `tff.learning.framework.ServerState` namedtuple.
    weights_delta: An update to the trainable variables of the model.
    model_fn: A no-arg function that returns a `tff.learning.Model`. Passing in
      a function ensures any variables are created when server_update_model is
      called, so they can be captured in a specific graph or other context.
    optimizer_fn: A no-arg function that returns a `tf.train.Optimizer`. As with
      model_fn, we pass in a function to control when variables are created.

  Returns:
    An updated `tff.learning.framework.ServerState`.
  """
  py_typecheck.check_type(current_server_state, ServerState)
  py_typecheck.check_type(weights_delta, collections.OrderedDict)
  model = model_utils.enhance(model_fn())
  optimizer = optimizer_fn()
  apply_delta_fn, server_state_vars = _create_optimizer_and_server_state(
      model, optimizer)

  # We might have a NaN value e.g. if all of the clients processed
  # had no data, so the denominator in the federated_mean is zero.
  # If we see any NaNs, zero out the whole update.
  no_nan_weights_delta, _ = tensor_utils.zero_all_if_any_non_finite(
      weights_delta)
  # TODO(b/124538167): We should increment a server counter to
  # track the fact a non-finite weights_delta was encountered.

  @tf.contrib.eager.function(autograph=False)
  def update_model_inner():
    """Applies the update."""
    nest.map_structure(tf.assign, server_state_vars, current_server_state)
    apply_delta_fn(no_nan_weights_delta)
    return server_state_vars

  return update_model_inner()
Example #24
    def __call__(self, dataset, initial_weights):
        model = self._model

        # TODO(b/113112108): Remove this temporary workaround and restore check for
        # `tf.data.Dataset` after subclassing the currently used custom data set
        # representation from it.
        if 'Dataset' not in str(type(dataset)):
            raise TypeError('Expected a data set, found {}.'.format(
                py_typecheck.type_string(type(dataset))))

        tf.nest.map_structure(lambda a, b: a.assign(b), model.weights,
                              initial_weights)
        flat_trainable_weights = tuple(tf.nest.flatten(
            model.weights.trainable))

        @tf.function
        def reduce_fn(state, batch):
            """Runs forward_pass on batch and sums the weighted gradients."""
            flat_accumulated_grads, batch_weight_sum = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch)
            flat_grads = tape.gradient(output.loss, flat_trainable_weights)

            if self._batch_weight_fn is not None:
                batch_weight = self._batch_weight_fn(batch)
            else:
                batch_weight = tf.cast(
                    tf.shape(output.predictions)[0], tf.float32)

            flat_accumulated_grads = tuple(
                accumulator + batch_weight * grad for accumulator, grad in zip(
                    flat_accumulated_grads, flat_grads))

            # The TF team is aware of an optimization in the reduce state to avoid
            # doubling the number of required variables here (e.g. keeping two copies
            # of all gradients). If you're looking to optimize memory usage this might
            # be a place to look.
            return (flat_accumulated_grads, batch_weight_sum + batch_weight)

        def _zero_initial_state():
            """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
            return (tuple(tf.zeros_like(w)
                          for w in flat_trainable_weights), tf.constant(0.0))

        flat_grad_sums, batch_weight_sum = self._dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=_zero_initial_state)
        grad_sums = tf.nest.pack_sequence_as(model.weights.trainable,
                                             flat_grad_sums)

        # For SGD, the delta is just the negative of the average gradient:
        weights_delta = tf.nest.map_structure(
            lambda gradient: -1.0 * gradient / batch_weight_sum, grad_sums)
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        if has_non_finite_delta > 0:
            weights_delta_weight = tf.constant(0.0)
        else:
            weights_delta_weight = batch_weight_sum
        return optimizer_utils.ClientOutput(
            weights_delta, weights_delta_weight, model.report_local_outputs(),
            tensor_utils.to_odict({
                'client_weight': weights_delta_weight,
                'has_non_finite_delta': has_non_finite_delta,
            }))
Example #25
    def client_update(dataset, initial_model_weights):
        """Performs client local model optimization.

    Args:
      dataset: A `tf.data.Dataset` that provides training examples.
      initial_model_weights: A `tff.learning.ModelWeights` containing the
        starting global trainable and non-trainable weights.

    Returns:
      A `ClientOutput`.
    """
        with tf.init_scope():
            model = model_fn()

            metrics = []
            if metrics_fn is not None:
                metrics.extend(metrics_fn())
            # To be used to calculate example-weighted mean across batches and
            # clients.
            metrics.append(keras_utils.MeanLossMetric(loss_fn()))
            # To be used to calculate batch loss for model updates.
            client_loss = loss_fn()

        global_model_weights = reconstruction_utils.get_global_variables(model)
        local_model_weights = reconstruction_utils.get_local_variables(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              initial_model_weights)
        client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            client_optimizer_fn,
            global_model_weights.trainable,
            disjoint_init_and_next=False)
        reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
            reconstruction_optimizer_fn,
            local_model_weights.trainable,
            disjoint_init_and_next=False)

        @tf.function
        def reconstruction_reduce_fn(state, batch):
            """Runs reconstruction training on local client batch."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      local_model_weights.trainable)
            optimizer_state, updated_weights = reconstruction_optimizer.next(
                optimizer_state, local_model_weights.trainable, gradients)
            if not isinstance(reconstruction_optimizer,
                              keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      local_model_weights.trainable,
                                      updated_weights)

            return num_examples_sum + output.num_examples

        @tf.function
        def train_reduce_fn(state, batch):
            """Runs one step of client optimizer on local client batch."""
            num_examples_sum, optimizer_state = state
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = client_loss(y_true=output.labels,
                                         y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      global_model_weights.trainable)
            optimizer_state, updated_weights = client_optimizer.next(
                optimizer_state, global_model_weights.trainable, gradients)
            if not isinstance(client_optimizer,
                              keras_optimizer.KerasOptimizer):
                # Keras optimizer mutates model variables within the `next` step.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      global_model_weights.trainable,
                                      updated_weights)

            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)

            return num_examples_sum + output.num_examples, optimizer_state

        recon_dataset, post_recon_dataset = dataset_split_fn(dataset)

        # If needed, do reconstruction, training the local variables while keeping
        # the global ones frozen.
        if local_model_weights.trainable:
            # Ignore the number of examples seen during reconstruction, since it
            # is not included in `client_weight`.
            def initial_state_reconstruction_reduce():
                trainable_tensor_specs = tf.nest.map_structure(
                    lambda v: tf.TensorSpec(v.shape, v.dtype),
                    local_model_weights.trainable)
                return tf.constant(0), reconstruction_optimizer.initialize(
                    trainable_tensor_specs)

            recon_dataset.reduce(
                initial_state=initial_state_reconstruction_reduce(),
                reduce_func=reconstruction_reduce_fn)

        # Train the global variables, keeping local variables frozen.
        def initial_state_train_reduce():
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                global_model_weights.trainable)
            return tf.constant(0), client_optimizer.initialize(
                trainable_tensor_specs)

        num_examples_sum, _ = post_recon_dataset.reduce(
            initial_state=initial_state_train_reduce(),
            reduce_func=train_reduce_fn)

        weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                              global_model_weights.trainable,
                                              initial_model_weights.trainable)

        # Zero out the update if `weights_delta` contains any non-finite values.
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        model_local_outputs = keras_utils.read_metric_variables(metrics)

        if has_non_finite_weight > 0:
            client_weight = tf.constant(0.0, dtype=tf.float32)
        elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
            client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
        elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
            client_weight = tf.constant(1.0, dtype=tf.float32)
        else:
            client_weight = client_weighting(model_local_outputs)

        return ClientOutput(weights_delta, client_weight, model_local_outputs)
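
The reconstruction example above assumes a `dataset_split_fn` that partitions each client's dataset into a reconstruction part (used only to fit the local variables) and a post-reconstruction part (used to train the global variables). As a minimal sketch, assuming a batched `tf.data.Dataset` per client, such a splitter could look like the following; the name `simple_dataset_split_fn` and the take/skip split are illustrative only (in practice the split function is typically obtained from `reconstruction_utils` with its own options).

import tensorflow as tf

def simple_dataset_split_fn(client_dataset: tf.data.Dataset):
    """Illustrative split: first batch reconstructs local variables, the rest trains global ones."""
    recon_dataset = client_dataset.take(1)
    post_recon_dataset = client_dataset.skip(1)
    return recon_dataset, post_recon_dataset
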
Example #26
0
def client_update(model, optimizer, dataset, initial_weights):
    """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from the server.

  Returns:
    A `ClientOutput`.
  """
    model_weights = tff.learning.framework.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)
    flat_trainable_weights = tuple(tf.nest.flatten(model_weights.trainable))

    @tf.function
    def reduce_fn(state, batch):
        """Train on local client batch, summing the gradients and gradients norm."""
        flat_accumulated_grads, flat_accumulated_grads_norm, batch_weight_sum = state

        # Unlike the FedAvg client update, we need to capture the gradients during
        # training so we can send their norms back to the server.
        with tf.GradientTape() as tape:
            output = model.forward_pass(batch)
        flat_grads = tape.gradient(output.loss, flat_trainable_weights)
        optimizer.apply_gradients(zip(flat_grads, flat_trainable_weights))
        batch_weight = tf.cast(tf.shape(output.predictions)[0],
                               dtype=tf.float32)
        flat_accumulated_grads = tuple(
            accumulator + batch_weight * grad
            for accumulator, grad in zip(flat_accumulated_grads, flat_grads))
        flat_accumulated_grads_norm = tuple(
            norm_accumulator + batch_weight * tf.norm(grad)
            for norm_accumulator, grad in zip(flat_accumulated_grads_norm,
                                              flat_grads))
        return (flat_accumulated_grads, flat_accumulated_grads_norm,
                batch_weight_sum + batch_weight)

    def _zero_initial_state():
        """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
        return (
            tuple(tf.zeros_like(w) for w in flat_trainable_weights),
            tuple(
                tf.constant(0, dtype=w.dtype) for w in flat_trainable_weights),
            tf.constant(0, dtype=tf.float32),
        )

    flat_grads_sum, flat_grads_norm_sum, batch_weight_sum = dataset.reduce(
        initial_state=_zero_initial_state(), reduce_func=reduce_fn)

    grads_sum = tf.nest.pack_sequence_as(model_weights.trainable,
                                         flat_grads_sum)
    weights_delta = tf.nest.map_structure(
        lambda gradient: -1.0 * gradient / batch_weight_sum, grads_sum)
    flat_grads_norm_sum = tf.nest.map_structure(
        lambda grad_norm: grad_norm / batch_weight_sum, flat_grads_norm_sum)

    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    # Zero out the weight if there are any non-finite values.
    if has_non_finite_delta > 0:
        weights_delta_weight = tf.constant(0.0)
    else:
        weights_delta_weight = batch_weight_sum

    return ClientOutput(weights_delta,
                        weights_delta_weight,
                        model_output=model.report_local_outputs(),
                        optimizer_output=collections.OrderedDict(
                            num_examples=batch_weight_sum,
                            flat_grads_norm_sum=flat_grads_norm_sum))
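
Example #26 leans on `tensor_utils.zero_all_if_any_non_finite` to drop an entire update when any component became NaN or Inf. A rough, self-contained equivalent is sketched below purely for illustration; it is not the TFF utility itself.

import tensorflow as tf

def zero_all_if_any_non_finite(structure):
    """Zeroes every tensor in `structure` if any element is NaN or Inf.

    Returns the (possibly zeroed) structure and an int32 flag that is 1 when
    zeroing happened and 0 otherwise.
    """
    flat = tf.nest.flatten(structure)
    all_finite = tf.reduce_all(
        [tf.reduce_all(tf.math.is_finite(t)) for t in flat])
    zeroed = tf.nest.map_structure(
        lambda t: tf.where(all_finite, t, tf.zeros_like(t)), structure)
    return zeroed, tf.cast(tf.logical_not(all_finite), tf.int32)
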
Example #27
0
        def client_update(global_optimizer_state, initial_weights, data):
            model_weights = model_utils.ModelWeights.from_model(model)
            tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                                  initial_weights)

            def full_gradient_reduce_fn(state, batch):
                """Sums individual gradients, to be later divided by num_examples."""
                gradient_sum, num_examples_sum = state
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                if output.num_examples is None:
                    num_examples = tf.shape(output.predictions,
                                            out_type=tf.int64)[0]
                else:
                    num_examples = tf.cast(output.num_examples, tf.int64)
                # TODO(b/161529310): We flatten and convert to a tuple, as tf.data
                # iterators would try to stack the tensors in a list into a single tensor.
                gradients = tuple(
                    tf.nest.flatten(
                        tape.gradient(output.loss, model_weights.trainable)))
                gradient_sum = tf.nest.map_structure(
                    lambda g_sum, g: g_sum + g * tf.cast(
                        num_examples, g.dtype), gradient_sum, gradients)
                num_examples_sum += num_examples
                return gradient_sum, num_examples_sum

            def initial_state_for_full_gradient_reduce_fn():
                initial_gradient_sum = tf.nest.map_structure(
                    lambda spec: tf.zeros(spec.shape, spec.dtype),
                    tuple(tf.nest.flatten(weight_tensor_specs.trainable)))
                initial_num_examples_sum = tf.constant(0, tf.int64)
                return initial_gradient_sum, initial_num_examples_sum

            full_gradient, num_examples = dataset_reduce_fn(
                full_gradient_reduce_fn, data,
                initial_state_for_full_gradient_reduce_fn)
            # Compute the average gradient.
            full_gradient = tf.nest.map_structure(
                lambda g: tf.math.divide_no_nan(
                    g, tf.cast(num_examples, g.dtype)), full_gradient)

            # Reset the local model variables, including metric states, as we are
            # only interested in metrics from the subsequent training, not in those
            # based on the full-gradient evaluation.
            model.reset_metrics()

            def train_reduce_fn(state, batch):
                with tf.GradientTape() as tape:
                    output = model.forward_pass(batch, training=True)
                gradients = tape.gradient(output.loss, model_weights.trainable)
                # Mime Lite keeps optimizer state unchanged during local training.
                _, updated_weights = optimizer.next(global_optimizer_state,
                                                    model_weights.trainable,
                                                    gradients)
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      model_weights.trainable, updated_weights)
                return state

            # Performs local training, updating `tf.Variable`s in `model_weights`.
            dataset_reduce_fn(train_reduce_fn,
                              data,
                              initial_state_fn=lambda: tf.zeros(shape=[0]))

            client_weights_delta = tf.nest.map_structure(
                tf.subtract, initial_weights.trainable,
                model_weights.trainable)
            model_output = model.report_local_unfinalized_metrics()

            # TODO(b/122071074): Consider moving this functionality into aggregation.
            client_weights_delta, has_non_finite_delta = (
                tensor_utils.zero_all_if_any_non_finite(client_weights_delta))
            client_weight = _choose_client_weight(client_weighting,
                                                  has_non_finite_delta,
                                                  num_examples)
            return client_works.ClientResult(
                update=client_weights_delta,
                update_weight=client_weight), model_output, full_gradient
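
The `_choose_client_weight` helper referenced above is not shown in this snippet. A minimal sketch of its presumable behavior follows (zero weight on a non-finite delta, otherwise the selected weighting scheme); treat it as an assumption rather than the library implementation, and note that it reuses the `client_weight_lib` import from the surrounding code.

import tensorflow as tf

def _choose_client_weight(weighting, has_non_finite_delta, num_examples):
    """Illustrative sketch: picks the aggregation weight for one client update."""
    if has_non_finite_delta > 0:
        # Give zero weight to updates that contained NaN/Inf values.
        return tf.constant(0.0, tf.float32)
    if weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
        return tf.cast(num_examples, tf.float32)
    # Otherwise weight all clients uniformly.
    return tf.constant(1.0, tf.float32)
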
    def __call__(self, dataset, initial_weights):
        model = self._model
        optimizer = self._optimizer
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                              initial_weights)

        def reduce_fn(state, batch):
            """Train `tff.learning.Model` on local client batch."""
            num_examples_sum, optimizer_state = state

            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)

            gradients = tape.gradient(output.loss, model_weights.trainable)
            optimizer_state, updated_weights = optimizer.next(
                optimizer_state, model_weights.trainable, gradients)
            if not isinstance(optimizer, keras_optimizer.KerasOptimizer):
                # A Keras optimizer already mutates the model variables inside
                # `next`; for other optimizers, assign the updated weights back.
                tf.nest.map_structure(lambda a, b: a.assign(b),
                                      model_weights.trainable, updated_weights)

            if output.num_examples is None:
                num_examples_sum += tf.shape(output.predictions,
                                             out_type=tf.int64)[0]
            else:
                num_examples_sum += tf.cast(output.num_examples, tf.int64)

            return num_examples_sum, optimizer_state

        def initial_state_for_reduce_fn():
            trainable_tensor_specs = tf.nest.map_structure(
                lambda v: tf.TensorSpec(v.shape, v.dtype),
                model_weights.trainable)
            return tf.zeros(
                shape=[],
                dtype=tf.int64), optimizer.initialize(trainable_tensor_specs)

        num_examples_sum, _ = self._dataset_reduce_fn(
            reduce_fn, dataset, initial_state_fn=initial_state_for_reduce_fn)

        weights_delta = tf.nest.map_structure(tf.subtract,
                                              model_weights.trainable,
                                              initial_weights.trainable)
        model_output = model.report_local_outputs()

        # TODO(b/122071074): Consider moving this functionality into
        # tff.federated_mean?
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # Zero out the weight if there are any non-finite values.
        if has_non_finite_delta > 0:
            # TODO(b/176171842): Zeroing has no effect with unweighted aggregation.
            weights_delta_weight = tf.constant(0.0)
        elif self._client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
            weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
        elif self._client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
            weights_delta_weight = tf.constant(1.0)
        else:
            weights_delta_weight = self._client_weighting(model_output)
        # TODO(b/176245976): TFF `ClientOutput` structure names are confusing.
        optimizer_output = collections.OrderedDict(
            num_examples=num_examples_sum)
        return optimizer_utils.ClientOutput(weights_delta,
                                            weights_delta_weight, model_output,
                                            optimizer_output)
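
Several snippets above accept a `dataset_reduce_fn` (or use `self._dataset_reduce_fn`) instead of calling `tf.data.Dataset.reduce` directly. In TFF this is typically a small switch between a fused `dataset.reduce` call and a plain Python loop over batches, the latter working better in GPU simulation. A hedged sketch of such a builder, approximating TFF's dataset-reduce helper rather than reproducing it verbatim:

import tensorflow as tf

def build_dataset_reduce_fn(use_python_loop: bool = True):
    """Returns a reduce helper with signature (reduce_fn, dataset, initial_state_fn)."""

    def _for_loop_reduce(reduce_fn, dataset, initial_state_fn):
        # Plain Python iteration over batches.
        state = initial_state_fn()
        for batch in iter(dataset):
            state = reduce_fn(state, batch)
        return state

    def _dataset_reduce(reduce_fn, dataset, initial_state_fn):
        # Single fused tf.data reduction.
        return dataset.reduce(initial_state=initial_state_fn(),
                              reduce_func=reduce_fn)

    return _for_loop_reduce if use_python_loop else _dataset_reduce
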