Example #1
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
                  weights_delta, grads_norm):
    """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_optimizer_vars: A list of previous variables of server_optimizer.
    server_state: A `ServerState` namedtuple, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    grads_norm: Summation of the norm of gradients from clients.

  Returns:
    An updated `ServerState`.
  """
    model_weights = tff.learning.framework.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t),
                          (model_weights, server_optimizer_vars),
                          (server_state.model, server_state.optimizer_state))

    # Zero out the weight if there are any non-finite values.
    weights_delta, _ = (tensor_utils.zero_all_if_any_non_finite(weights_delta))

    grads_and_vars = tf.nest.map_structure(
        lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
        tf.nest.flatten(model_weights.trainable))

    server_optimizer.update_grads_norm(
        tf.nest.flatten(model_weights.trainable), grads_norm)
    server_optimizer.apply_gradients(grads_and_vars, name='server_update')

    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=server_optimizer_vars)
Example #2
def server_update(model, server_optimizer, server_state, weights_delta):
    """Updates `server_state` based on `weights_delta`, increase the round number.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.

  Returns:
    An updated `ServerState`.
  """
    model_weights = _get_weights(model)
    tff.utils.assign(model_weights, server_state.model)
    # Server optimizer variables must be initialized prior to invoking this function.
    tff.utils.assign(server_optimizer.variables(),
                     server_state.optimizer_state)

    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    if has_non_finite_weight > 0:
        return server_state

    # Apply the update to the model. We must multiply weights_delta by -1.0 to
    # view it as a gradient that should be applied to the server_optimizer.
    grads_and_vars = [(-1.0 * x, v)
                      for x, v in zip(weights_delta, model_weights.trainable)]

    server_optimizer.apply_gradients(grads_and_vars)

    # Create a new state based on the updated model.
    return tff.utils.update_state(server_state,
                                  model=model_weights,
                                  optimizer_state=server_optimizer.variables(),
                                  round_num=server_state.round_num + 1.0)
Example #3
def server_update(model, server_optimizer, server_state, weights_delta):
    """Updates `server_state` based on `weights_delta`.

  Args:
    model: A `KerasModelWrapper` or `tff.learning.Model`.
    server_optimizer: A `ServerOptimizerBase`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: A nested structure of tensors holding the updates to the
      trainable variables of the model.

  Returns:
    An updated `ServerState`.
  """
    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    if has_non_finite_weight > 0:
        return server_state

    # Initialize the model with the current state.
    model_weights = model.weights
    tff.utils.assign(model_weights, server_state.model_weights)

    # Apply the update to the model, and return the updated state.
    grad = tf.nest.map_structure(lambda x: -1.0 * x, weights_delta)
    optimizer_state = server_optimizer.model_update(
        state=server_state.optimizer_state,
        weight=model_weights.trainable,
        grad=grad,
        round_idx=server_state.round_num)

    # Create a new state based on the updated model.
    return tff.utils.update_state(server_state,
                                  model_weights=model_weights,
                                  optimizer_state=optimizer_state,
                                  round_num=server_state.round_num + 1)
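
Example #3 delegates the update step to a `ServerOptimizerBase` object instead of a Keras optimizer. The sketch below is a minimal, plain-SGD reading of the contract implied by the `model_update` call above; only that call signature is taken from the example, the class name and everything else are assumptions.

import tensorflow as tf

class SGDServerOptimizer(object):
  """Hypothetical plain-SGD server optimizer matching the call in Example #3."""

  def __init__(self, learning_rate):
    self._lr = learning_rate

  def init_state(self):
    # Plain SGD keeps no state; richer optimizers would return slot variables.
    return ()

  def model_update(self, state, weight, grad, round_idx):
    del round_idx  # Unused by plain SGD.
    tf.nest.map_structure(lambda w, g: w.assign_sub(self._lr * g), weight, grad)
    return state
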
Example #4
def client_computation(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        train_discriminator_fn) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    train_discriminator_fn: A function which takes the two networks, generator
      input, and real data and trains the discriminator.

  Returns:
    A `ClientOutput` object.
  """
    tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
                          from_server.discriminator_weights)

    num_examples = tf.constant(0)
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        num_examples += train_discriminator_fn(generator, discriminator,
                                               gen_inputs, real_data)

    weights_delta = tf.nest.map_structure(tf.subtract, discriminator.weights,
                                          from_server.discriminator_weights)
    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight = tf.cond(tf.equal(has_non_finite_delta, 0),
                            lambda: update_weight, lambda: tf.constant(0.0))
    return ClientOutput(
        discriminator_weights_delta=weights_delta,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})
Example #5
 def expect_zeros(structure, expected):
   with tf.Graph().as_default():
     result, error = tensor_utils.zero_all_if_any_non_finite(structure)
     with self.session() as sess:
       result, error = sess.run((result, error))
     try:
       tf.nest.map_structure(np.testing.assert_allclose, result, expected)
     except AssertionError:
       self.fail('Expected to get zeros, but instead got {}'.format(result))
     self.assertEqual(error, 1)
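
All of these examples lean on `tensor_utils.zero_all_if_any_non_finite`, which the test helper above exercises. The following is a behavioral sketch only, assuming a structure of float tensors; the real helper in `tensor_utils` may differ in implementation details.

import tensorflow as tf

def zero_all_if_any_non_finite(structure):
  """Returns (structure, 0), or (all-zeros structure, 1) if any value is non-finite."""
  flat = tf.nest.flatten(structure)
  # Assumes float tensors; tf.math.is_finite is undefined for integer dtypes.
  all_finite = tf.reduce_all([tf.reduce_all(tf.math.is_finite(t)) for t in flat])
  new_flat, error = tf.cond(
      all_finite,
      lambda: (flat, tf.constant(0)),
      lambda: ([tf.zeros_like(t) for t in flat], tf.constant(1)))
  return tf.nest.pack_sequence_as(structure, new_flat), error
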
Example #6
  def client_update(model,
                    dataset,
                    initial_weights,
                    client_optimizer,
                    client_weight_fn=None):
    """Updates client model.

    Args:
      model: A `tff.learning.Model`.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.ModelWeights` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer` object.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.

    Returns:
      A `ClientOutput`.
    """

    model_weights = _get_weights(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    for batch in dataset:
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch)
      grads = tape.gradient(output.loss, model_weights.trainable)
      grads_and_vars = zip(grads, model_weights.trainable)
      client_optimizer.apply_gradients(grads_and_vars)
      num_examples += tf.shape(output.predictions)[0]

    aggregated_outputs = model.report_local_outputs()
    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)
    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))

    if has_non_finite_weight > 0:
      client_weight = tf.constant(0, dtype=tf.float32)
    elif client_weight_fn is None:
      client_weight = tf.cast(num_examples, dtype=tf.float32)
    else:
      client_weight = client_weight_fn(aggregated_outputs)

    return ClientOutput(
        weights_delta, client_weight, aggregated_outputs,
        collections.OrderedDict([('num_examples', num_examples)]))
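
As a hedged illustration of the optional `client_weight_fn` hook documented above: the function below weights clients by the square root of their example count instead of the raw count. It assumes the model's aggregated local outputs expose a `num_examples` entry, which is not guaranteed for every `tff.learning.Model`.

import tensorflow as tf

def sqrt_num_examples_weight_fn(local_outputs):
  # `local_outputs` is the structure returned by `model.report_local_outputs()`.
  return tf.math.sqrt(tf.cast(local_outputs['num_examples'], tf.float32))

# Usage (sketch): client_update(model, dataset, initial_weights, client_optimizer,
#                               client_weight_fn=sqrt_num_examples_weight_fn)
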
Example #7
    def server_update(model, server_optimizer, server_optimizer_vars,
                      server_state, weights_delta, aggregator_state):
        """Updates `server_state` based on `weights_delta`.

    Args:
      model: A `ReconstructionModel`.
      server_optimizer: A `tf.keras.optimizers.Optimizer`.
      server_optimizer_vars: A list of variables of server_optimizer.
      server_state: A `ServerState`, the state to be updated.
      weights_delta: An update to the trainable variables of the model.
      aggregator_state: The state of the aggregator after performing
        aggregation.

    Returns:
      An updated `ServerState`.
    """
        global_model_weights = reconstruction_utils.get_global_variables(model)
        # Initialize the model with the current state.
        tf.nest.map_structure(
            lambda a, b: a.assign(b),
            (global_model_weights, server_optimizer_vars),
            (server_state.model, server_state.optimizer_state))

        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        # We ignore the update if the weights_delta is non finite.
        if has_non_finite_weight > 0:
            return tff.utils.update_state(
                server_state,
                model=global_model_weights,
                optimizer_state=server_optimizer_vars,
                round_num=server_state.round_num + 1,
                aggregator_state=aggregator_state)

        # Apply the update to the model.
        grads_and_vars = tf.nest.map_structure(
            lambda x, v: (-1.0 * x, v), tf.nest.flatten(weights_delta),
            tf.nest.flatten(global_model_weights.trainable))
        server_optimizer.apply_gradients(grads_and_vars, name='server_update')

        # Create a new state based on the updated model.
        return tff.utils.update_state(
            server_state,
            model=global_model_weights,
            optimizer_state=server_optimizer_vars,
            round_num=server_state.round_num + 1,
            aggregator_state=aggregator_state,
        )
Example #8
def server_update(model,
                  server_optimizer,
                  server_state,
                  weights_delta,
                  global_cor=None):
    """Updates `server_state` based on `weights_delta`, increase the round number.

  Args:
    model: A `tff.learning.Model`.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    global_cor: Optional. A correction to the update of `weights_delta`.

  Returns:
    An updated `ServerState`.
  """
    model_weights = _get_weights(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          server_state.model)
    # Server optimizer variables must be initialized prior to invoking this function.
    tf.nest.map_structure(lambda v, t: v.assign(t),
                          server_optimizer.variables(),
                          server_state.optimizer_state)

    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    if has_non_finite_weight > 0:
        return server_state

    if global_cor is not None:
        weights_delta = tf.nest.map_structure(tf.math.divide_no_nan,
                                              weights_delta, global_cor)

    # Apply the update to the model. We must multiply weights_delta by -1.0 to
    # view it as a gradient that should be applied to the server_optimizer.
    grads_and_vars = [(-1.0 * x, v)
                      for x, v in zip(weights_delta, model_weights.trainable)]

    server_optimizer.apply_gradients(grads_and_vars)

    # Create a new state based on the updated model.
    return tff.structure.update_struct(
        server_state,
        model=model_weights,
        optimizer_state=server_optimizer.variables(),
        round_num=server_state.round_num + 1.0)
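
A hedged usage sketch of the `global_cor` argument above: because the correction divides the delta elementwise, an all-ones correction leaves it unchanged and this variant reduces to the plain server update. The `model`, `server_optimizer`, `server_state`, and `weights_delta` values are assumed to already exist in scope.

identity_cor = tf.nest.map_structure(tf.ones_like, weights_delta)
new_state = server_update(model, server_optimizer, server_state,
                          weights_delta, global_cor=identity_cor)
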
Example #9
    def __call__(self, dataset, initial_weights):
        del initial_weights
        model = self._model

        @tf.function
        def reduce_fn_num_examples(num_examples_sum, batch):
            """Count number of examples."""
            num_examples_in_batch = tf.shape(batch['x'])[0]
            return num_examples_sum + num_examples_in_batch

        @tf.function
        def reduce_fn_dataset_mean(sum_vector, batch):
            """Sum all the examples in the local dataset."""
            sum_batch = tf.reshape(tf.reduce_sum(batch['x'], [0]), (-1, 1))
            return sum_vector + sum_batch

        num_examples_sum = dataset.reduce(initial_state=tf.constant(0),
                                          reduce_func=reduce_fn_num_examples)
        example_vector_sum = dataset.reduce(initial_state=tf.zeros((DIM, 1)),
                                            reduce_func=reduce_fn_dataset_mean)

        # create a list with the same structure and type as model.trainable
        # containing a mean of all the examples in the local dataset. Note: this
        # works for a linear model only (as in the example above)
        weights_delta = [
            example_vector_sum / tf.cast(num_examples_sum, tf.float32)
        ]
        aggregated_outputs = model.report_local_outputs()
        weights_delta, has_non_finite_delta = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))
        weights_delta_weight = tf.cast(num_examples_sum, tf.float32)

        return tff.learning.framework.ClientOutput(
            weights_delta, weights_delta_weight, aggregated_outputs,
            collections.OrderedDict(
                num_examples=num_examples_sum,
                has_non_finite_delta=has_non_finite_delta,
            ))
Example #10
def client_update(model, optimizer, dataset, initial_weights):
    """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    optimizer: A `tf.keras.optimizers.Optimizer`.
    dataset: A `tf.data.Dataset`.
    initial_weights: A `tff.learning.Model.weights` from server.

  Returns:
    A `ClientOutput`.
  """
    model_weights = tff.learning.framework.ModelWeights.from_model(model)
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)
    flat_trainable_weights = tuple(tf.nest.flatten(model_weights.trainable))

    @tf.function
    def reduce_fn(state, batch):
        """Train on local client batch, summing the gradients and gradients norm."""
        flat_accumulated_grads, flat_accumulated_grads_norm, batch_weight_sum = state

        # Unlike the FedAvg client update, we need to capture the gradients during
        # training so we can send back the norms to the server.
        with tf.GradientTape() as tape:
            output = model.forward_pass(batch)
        flat_grads = tape.gradient(output.loss, flat_trainable_weights)
        optimizer.apply_gradients(zip(flat_grads, flat_trainable_weights))
        batch_weight = tf.cast(tf.shape(output.predictions)[0],
                               dtype=tf.float32)
        flat_accumulated_grads = tuple(
            accumulator + batch_weight * grad
            for accumulator, grad in zip(flat_accumulated_grads, flat_grads))
        flat_accumulated_grads_norm = tuple(
            norm_accumulator + batch_weight * tf.norm(grad)
            for norm_accumulator, grad in zip(flat_accumulated_grads_norm,
                                              flat_grads))
        return (flat_accumulated_grads, flat_accumulated_grads_norm,
                batch_weight_sum + batch_weight)

    def _zero_initial_state():
        """Create a tuple of (tuple of gradient accumulators, batch weight sum)."""
        return (
            tuple(tf.zeros_like(w) for w in flat_trainable_weights),
            tuple(
                tf.constant(0, dtype=w.dtype) for w in flat_trainable_weights),
            tf.constant(0, dtype=tf.float32),
        )

    flat_grads_sum, flat_grads_norm_sum, batch_weight_sum = dataset.reduce(
        initial_state=_zero_initial_state(), reduce_func=reduce_fn)

    grads_sum = tf.nest.pack_sequence_as(model_weights.trainable,
                                         flat_grads_sum)
    weights_delta = tf.nest.map_structure(
        lambda gradient: -1.0 * gradient / batch_weight_sum, grads_sum)
    flat_grads_norm_sum = tf.nest.map_structure(
        lambda grad_norm: grad_norm / batch_weight_sum, flat_grads_norm_sum)

    weights_delta, has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))
    # Zero out the weight if there are any non-finite values.
    if has_non_finite_delta > 0:
        weights_delta_weight = tf.constant(0.0)
    else:
        weights_delta_weight = batch_weight_sum

    return ClientOutput(weights_delta,
                        weights_delta_weight,
                        model_output=model.report_local_outputs(),
                        optimizer_output=collections.OrderedDict(
                            num_examples=batch_weight_sum,
                            flat_grads_norm_sum=flat_grads_norm_sum))
Example #11
def client_update(model,
                  dataset,
                  num_epochs,
                  initial_weights,
                  client_optimizer,
                  client_mixedin_fn,
                  client_update_delta_fn,
                  client_single_data_pass_fn,
                  client_weight_fn=None):
    """Updates client model.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    num_epochs: The number of epochs or dataset passes.
    initial_weights: A `tff.learning.Model.weights` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer` object.
    client_mixedin_fn: A function that takes the outputs of the previous and
      current epoch and returns a boolean indicating whether the SG-MCMC has
      mixed in, in which case the following epochs can be used to produce
      approximate posterior samples.
    client_update_delta_fn: A function for updating the weights delta as new
      posterior samples become available.
    client_single_data_pass_fn: A function for taking a single pass over the
      client data to update the model and compute necessary outputs.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the weight
      in the federated average of model deltas. If not provided, the default is
      the total number of examples processed on device.

  Returns:
    A `ClientOutput`.
  """
    model_weights = _get_weights(model)
    initial_weights = tff.learning.ModelWeights(
        trainable=tuple(initial_weights.trainable),
        non_trainable=tuple(initial_weights.non_trainable))
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    # Initialize updates.
    mixedin = tf.constant(False, dtype=tf.bool)
    updates = DeltaUpdateOutput.from_weights(
        initial_weights=initial_weights.trainable,
        updated_weights=initial_weights.trainable)

    # Keep iterating over the data and refining weight deltas.
    num_examples = 0.0
    for epoch in tf.range(num_epochs):
        outputs = client_single_data_pass_fn(model=model,
                                             dataset=dataset,
                                             client_optimizer=client_optimizer)
        mixedin = client_mixedin_fn(epoch, mixedin, outputs)
        updates = client_update_delta_fn(mixedin=mixedin,
                                         initial_weights=initial_weights,
                                         data_pass_outputs=outputs,
                                         previous_updates=updates)
        num_examples = outputs.num_examples

    # Check for non-finite weights.
    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(updates.weights_delta))
    model_output = model.report_local_outputs()
    optimizer_output = collections.OrderedDict(num_examples=num_examples)
    weights_delta_zeros_percent = _compute_zeros_percentage(weights_delta)

    if has_non_finite_weight > 0:
        client_weight = tf.constant(0, dtype=tf.float32)
    elif client_weight_fn is None:
        client_weight = tf.cast(num_examples, dtype=tf.float32)
    else:
        client_weight = client_weight_fn(model_output)

    # Compute the L2 norm of the difference between corrected/uncorrected deltas.
    weights_delta_uncorrected = tf.nest.map_structure(
        lambda a, b: a - b, model_weights.trainable, initial_weights.trainable)
    weights_delta_correction = _compute_l2_difference(
        weights_delta_uncorrected, updates.weights_delta)
    additional_output = collections.OrderedDict(
        model_delta_zeros_percent=weights_delta_zeros_percent,
        model_delta_correction_l2_norm=weights_delta_correction,
    )

    return ClientOutput(weights_delta=weights_delta,
                        client_weight=client_weight,
                        model_output=model_output,
                        optimizer_output=optimizer_output,
                        additional_output=additional_output)
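
As a hedged illustration of the `client_mixedin_fn` hook above, the rule below declares the SG-MCMC chain mixed in after a fixed number of burn-in epochs. The argument order follows the call site inside the epoch loop; the burn-in threshold and the decision to ignore the pass outputs are assumptions.

import tensorflow as tf

def fixed_burnin_mixedin_fn(epoch, mixedin, outputs, burnin_epochs=5):
  del outputs  # This simple rule looks only at the epoch counter.
  return tf.logical_or(mixedin, epoch >= burnin_epochs)
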
Example #12
  def client_update(model,
                    dataset,
                    initial_weights,
                    client_optimizer,
                    client_weight_fn=None):
    """Updates client model.

    Args:
      model: A `tff.learning.Model`.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.ModelWeights` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer` object.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor that provides the
        weight in the federated average of model deltas. If not provided, the
        default is the total number of examples processed on device.

    Returns:
      A `ClientOutput`.
    """

    model_weights = _get_weights(model)
    tff.utils.assign(model_weights, initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    # Need to replace names of following two variables.
    m_states = tf.nest.map_structure(tf.zeros_like, model_weights.trainable)
    n_states = tf.nest.map_structure(tf.zeros_like, model_weights.trainable)
    for batch in iter(dataset):
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch)
      grads = tape.gradient(output.loss, model_weights.trainable)
      grads_and_vars = zip(grads, model_weights.trainable)
      client_optimizer.apply_gradients(grads_and_vars)
      num_examples += tf.shape(output.predictions)[0]
      client_opt_beta = _get_optimizer_momentum_beta(client_optimizer)
      client_opt_preconditioner = _get_optimizer_preconditioner(
          client_optimizer, model_weights)

      m_states = tf.nest.map_structure(
          lambda m, p, b=client_opt_beta: b * m + (1 - b) * p,
          m_states,
          client_opt_preconditioner)
      n_states = tf.nest.map_structure(lambda m, n: m + n, m_states, n_states)

    aggregated_outputs = model.report_local_outputs()
    weights_delta = tf.nest.map_structure(
        lambda a, b, c: tf.math.divide_no_nan(a - b, c),
        model_weights.trainable, initial_weights.trainable, n_states)

    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))

    if has_non_finite_weight > 0:
      client_weight = tf.constant(0, dtype=tf.float32)
    elif client_weight_fn is None:
      client_weight = tf.cast(num_examples, dtype=tf.float32)
    else:
      client_weight = client_weight_fn(aggregated_outputs)

    return ClientOutput(
        weights_delta, client_weight, aggregated_outputs,
        collections.OrderedDict([('num_examples', num_examples),
                                 ('n_states', n_states)]))
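
Example #12 assumes helpers `_get_optimizer_momentum_beta` and `_get_optimizer_preconditioner` that are not shown. Below is a hedged sketch of the first one, reading the first-moment decay rate from the optimizer's config and falling back to SGD momentum or zero; the helpers in the original project may behave differently.

import tensorflow as tf

def _get_optimizer_momentum_beta(optimizer):
  # Keras optimizers expose hyperparameters via get_config(): Adam-style
  # optimizers report 'beta_1', SGD reports 'momentum'; fall back to 0.0.
  config = optimizer.get_config()
  beta = config.get('beta_1', config.get('momentum', 0.0))
  return tf.constant(beta, dtype=tf.float32)
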
Example #13
    def client_update(model, metrics, batch_loss_fn, dataset, initial_weights,
                      client_optimizer, reconstruction_optimizer, round_num):
        """Updates client model.

    Outputted weight deltas represent the difference between final global
    variables and initial ones. The client weight (used in aggregation across
    clients) is the sum of the number of examples across all batches
    post-reconstruction (that is, only the local steps that involve updating
    global variables).

    Args:
      model: A `ReconstructionModel`.
      metrics: A List of `tf.keras.metrics.Metric`s containing metrics to be
        computed and aggregated across clients.
      batch_loss_fn: A `tf.keras.losses.Loss` used to compute batch loss on
        `BatchOutput.predictions` (y_pred) and `BatchOutput.labels` (y_true) for
        each batch during and after reconstruction.
      dataset: A `tf.data.Dataset`.
      initial_weights: A `tff.learning.ModelWeights` containing global trainable
        and non-trainable weights from the server.
      client_optimizer: a `tf.keras.optimizers.Optimizer` for training after the
        reconstruction step.
      reconstruction_optimizer: a `tf.keras.optimizers.Optimizer` for
        reconstruction of local trainable variables.
      round_num: the federated training round number, 1-indexed.

    Returns:
      A `reconstruction_utils.ClientOutput`.
    """
        global_model_weights = reconstruction_utils.get_global_variables(model)
        local_model_weights = reconstruction_utils.get_local_variables(model)
        tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                              initial_weights)

        @tf.function
        def reconstruction_reduce_fn(num_examples_sum, batch):
            """Runs reconstruction training on local client batch."""
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = batch_loss_fn(y_true=output.labels,
                                           y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      local_model_weights.trainable)
            reconstruction_optimizer.apply_gradients(
                zip(gradients, local_model_weights.trainable))

            # Update metrics if needed.
            if evaluate_reconstruction:
                for metric in metrics:
                    metric.update_state(y_true=output.labels,
                                        y_pred=output.predictions)

            return num_examples_sum + output.num_examples

        @tf.function
        def train_reduce_fn(num_examples_sum, batch):
            """Runs one step of client optimizer on local client batch."""
            if jointly_train_variables:
                all_trainable_variables = (global_model_weights.trainable +
                                           local_model_weights.trainable)
            else:
                all_trainable_variables = global_model_weights.trainable
            with tf.GradientTape() as tape:
                output = model.forward_pass(batch, training=True)
                batch_loss = batch_loss_fn(y_true=output.labels,
                                           y_pred=output.predictions)

            gradients = tape.gradient(batch_loss, all_trainable_variables)
            client_optimizer.apply_gradients(
                zip(gradients, all_trainable_variables))

            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)

            return num_examples_sum + output.num_examples

        recon_dataset, post_recon_dataset = dataset_split_fn(
            dataset, round_num)

        # If needed, do reconstruction, training the local variables while keeping
        # the global ones frozen.
        if local_model_weights.trainable:
            # Ignore output number of examples used in reconstruction, since this
            # isn't included in `client_weight`.
            recon_dataset.reduce(initial_state=tf.constant(0),
                                 reduce_func=reconstruction_reduce_fn)

        # Train the global variables, possibly jointly with local variables if
        # jointly_train_variables is True.
        num_examples_sum = post_recon_dataset.reduce(
            initial_state=tf.constant(0), reduce_func=train_reduce_fn)

        weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                              global_model_weights.trainable,
                                              initial_weights.trainable)

        # We ignore the update if the weights_delta is non finite.
        weights_delta, has_non_finite_weight = (
            tensor_utils.zero_all_if_any_non_finite(weights_delta))

        model_local_outputs = keras_utils.read_metric_variables(metrics)

        if has_non_finite_weight > 0:
            client_weight = tf.constant(0, dtype=tf.float32)
        elif client_weight_fn is None:
            client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
        else:
            client_weight = client_weight_fn(model_local_outputs)

        return reconstruction_utils.ClientOutput(weights_delta, client_weight,
                                                 model_local_outputs)
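
The closure above relies on a `dataset_split_fn` (plus flags such as `evaluate_reconstruction`, `jointly_train_variables`, and `client_weight_fn`) bound outside the snippet. A hedged sketch of one possible `dataset_split_fn`, sending even-numbered batches to reconstruction and odd-numbered batches to post-reconstruction training, follows; the real federated reconstruction splitting logic is configurable and may depend on the round number.

import tensorflow as tf

def dataset_split_fn(client_dataset, round_num):
  del round_num  # This simple variant ignores the round number.
  recon_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 0).map(lambda _, batch: batch)
  post_recon_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 1).map(lambda _, batch: batch)
  return recon_dataset, post_recon_dataset
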
Example #14
def client_control(
        # Tensor/Dataset arguments that will be supplied by TFF:
        gen_inputs_ds: tf.data.Dataset,
        real_data_ds: tf.data.Dataset,
        from_server: FromServer,
        # Python arguments to be bound at TFF computation construction time:
        generator: tf.keras.Model,
        discriminator: tf.keras.Model,
        disc_optimizer: tf.keras.optimizers.Optimizer,
        gen_optimizer: tf.keras.optimizers.Optimizer,
        zero_disc: tf.keras.Model,
        zero_gen: tf.keras.Model,
        tau: float) -> ClientOutput:
    """The computation to run on the client, training the discriminator.

  Args:
    gen_inputs_ds: A `tf.data.Dataset` of generator_inputs.
    real_data_ds: A `tf.data.Dataset` of data from the real distribution.
    from_server: A `FromServer` object, including the current model weights.
    generator:  The generator.
    discriminator: The discriminator.
    disc_optimizer: A `tf.keras.optimizers.Optimizer` for the discriminator.
    gen_optimizer: A `tf.keras.optimizers.Optimizer` for the generator.
    zero_disc: A `tf.keras.Model` whose weights buffer the accumulated
      discriminator updates.
    zero_gen: A `tf.keras.Model` whose weights buffer the accumulated generator
      updates.
    tau: Coefficient of the proximal (L2) term that keeps the local generator
      and discriminator close to the server's `meta_gen`/`meta_disc` weights.

  Returns:
    A `ClientOutput` object.
  """
    tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
                          from_server.discriminator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), zero_gen.weights,
                          from_server.generator_weights)
    tf.nest.map_structure(lambda a, b: a.assign(b), zero_disc.weights,
                          from_server.discriminator_weights)
    num_examples = tf.constant(0)
    meta_gen = from_server.meta_gen
    meta_disc = from_server.meta_disc
    gen_inputs_and_real_data = tf.data.Dataset.zip(
        (gen_inputs_ds, real_data_ds))
    loss_fns = gan_losses.WassersteinGanLossFns()
    for gen_inputs, real_data in gen_inputs_and_real_data:
        # It's possible that real_data and gen_inputs have different batch sizes.
        # For calculating the discriminator loss, it's desirable to have equal-sized
        # contributions from both the real and fake data. Also, it's necessary if
        # using the Wasserstein gradient penalty (where a difference is taken b/w
        # the real and fake data). So here we reduce to the min batch size. This
        # also ensures num_examples properly reflects the amount of data trained on.
        min_batch_size = tf.minimum(
            tf.shape(real_data)[0],
            tf.shape(gen_inputs)[0])
        real_data = real_data[0:min_batch_size]
        gen_inputs = gen_inputs[0:min_batch_size]
        # reset the gen/discriminator values so there's no moving
        #tf.nest.map_structure(lambda a, b: a.assign(b), generator.weights,
        #from_server.generator_weights)
        #tf.nest.map_structure(lambda a, b: a.assign(b), discriminator.weights,
        #from_server.discriminator_weights)
        with tf.GradientTape() as tape_gen:
            gen_loss = loss_fns.generator_loss(generator, discriminator,
                                               gen_inputs)
            for i in range(len(generator.weights)):
                gen_loss += tau * tf.nn.l2_loss(generator.weights[i] -
                                                meta_gen[i])
        with tf.GradientTape() as tape_disc:
            disc_loss = loss_fns.discriminator_loss(generator, discriminator,
                                                    gen_inputs, real_data)
            for i in range(len(discriminator.weights)):
                disc_loss += tau * tf.nn.l2_loss(discriminator.weights[i] -
                                                 meta_disc[i])
        # get disc grads
        disc_grads = tape_disc.gradient(disc_loss, discriminator.weights)
        disc_grads_and_vars = zip(disc_grads, discriminator.weights)

        # get gen grads
        gen_grads = tape_gen.gradient(gen_loss, generator.weights)
        gen_grads_and_vars = zip(gen_grads, generator.weights)

        disc_grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v),
                                                    disc_grads,
                                                    zero_disc.weights)
        gen_grads_and_vars = tf.nest.map_structure(lambda x, v: (x, v),
                                                   gen_grads, zero_gen.weights)
        #apply the gradients
        disc_optimizer.apply_gradients(disc_grads_and_vars)
        gen_optimizer.apply_gradients(gen_grads_and_vars)

        #find the deltas
        #disc_delta = tf.nest.map_structure(tf.subtract, discriminator.weights,
        #from_server.discriminator_weights)

        #gen_delta = tf.nest.map_structure(tf.subtract, generator.weights,
        #from_server.generator_weights)
        # add to buffers
        #zero_disc = tf.nest.map_structure(lambda a, b: a + b, zero_disc, disc_delta)
        #zero_gen = tf.nest.map_structure(lambda a, b: a + b, zero_gen, gen_delta)

        num_examples += min_batch_size
    num_examples_float = tf.cast(num_examples, tf.float32)
    disc_delta = tf.nest.map_structure(
        lambda a, b: (a - b) / num_examples_float, zero_disc.weights,
        from_server.discriminator_weights)
    gen_delta = tf.nest.map_structure(
        lambda a, b: (a - b) / num_examples_float, zero_gen.weights,
        from_server.generator_weights)

    disc_delta, disc_has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(disc_delta))
    gen_delta, gen_has_non_finite_delta = (
        tensor_utils.zero_all_if_any_non_finite(gen_delta))
    update_weight = tf.cast(num_examples, tf.float32)
    # Zero out the weight if there are any non-finite values.
    # TODO(b/122071074): federated_mean might not do the right thing if
    # all clients have zero weight.
    update_weight_disc = tf.cond(tf.equal(disc_has_non_finite_delta,
                                          0), lambda: update_weight,
                                 lambda: tf.constant(0.0))
    update_weight_gen = tf.cond(tf.equal(gen_has_non_finite_delta,
                                         0), lambda: update_weight,
                                lambda: tf.constant(0.0))
    update_weight = tf.math.minimum(update_weight_disc, update_weight_gen)
    return ClientOutput(
        discriminator_weights_delta=disc_delta,
        generator_weights_delta=gen_delta,
        update_weight=update_weight,
        counters={'num_discriminator_train_examples': num_examples})