Example #1
    def client_delta_tf(tf_dataset, initial_model_weights, round_num):
        """Performs client local model optimization.

        Args:
          tf_dataset: a `tf.data.Dataset` that provides training examples.
          initial_model_weights: a `tff.learning.ModelWeights` containing the
              starting global trainable and non_trainable weights.
          round_num: the federated training round number, 1-indexed.

        Returns:
          A `ClientOutput`.
        """
        model = model_fn()
        client_optimizer = client_optimizer_fn()
        reconstruction_optimizer = reconstruction_optimizer_fn()

        metrics = []
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        # To be used to calculate example-weighted mean across batches and clients.
        metrics.append(keras_utils.MeanLossMetric(loss_fn()))
        # To be used to calculate batch loss for model updates.
        batch_loss_fn = loss_fn()

        return client_update(model, metrics, batch_loss_fn, tf_dataset,
                             initial_model_weights, client_optimizer,
                             reconstruction_optimizer, round_num)
    def test_mean_loss_metric_from_keras_loss(self):
        mse_loss = tf.keras.losses.MeanSquaredError()
        mse_metric = keras_utils.MeanLossMetric(mse_loss)

        y_true = tf.ones([10, 1], dtype=tf.float32)
        y_pred = tf.ones([10, 1], dtype=tf.float32) * 0.5

        mse_metric.update_state(y_true, y_pred)
        self.assertEqual(mse_loss(y_true, y_pred), mse_metric.result())
    def test_mean_loss_metric_from_fn(self):
        """Ensures the mean loss metric also works with a callable."""
        def mse_loss(y_true, y_pred):
            return tf.reduce_mean(tf.square(y_true - y_pred))

        mse_metric = keras_utils.MeanLossMetric(mse_loss)

        y_true = tf.ones([10, 1], dtype=tf.float32)
        y_pred = tf.ones([10, 1], dtype=tf.float32) * 0.5

        mse_metric.update_state(y_true, y_pred)
        self.assertEqual(mse_loss(y_true, y_pred), mse_metric.result())
    def test_recreate_mean_loss_from_keras_loss(self):
        """Ensures we can create a metric from config, as is done in aggregation."""
        mse_loss = tf.keras.losses.MeanSquaredError()
        mse_metric = keras_utils.MeanLossMetric(mse_loss)
        recreated_mse_metric = type(mse_metric).from_config(
            mse_metric.get_config())

        y_true = tf.ones([10, 1], dtype=tf.float32)
        y_pred = tf.ones([10, 1], dtype=tf.float32) * 0.5

        mse_metric.update_state(y_true, y_pred)
        recreated_mse_metric.update_state(y_true, y_pred)

        self.assertEqual(recreated_mse_metric.result(), mse_metric.result())
    def test_mean_loss_metric_multiple_weighted_batches(self):
        mse_loss = tf.keras.losses.MeanSquaredError()
        mse_metric = keras_utils.MeanLossMetric(mse_loss)

        y_true = tf.ones([10, 1], dtype=tf.float32)
        y_pred = tf.ones([10, 1], dtype=tf.float32) * 0.5
        mse_metric.update_state(y_true, y_pred)

        y_true = tf.ones([40, 1], dtype=tf.float32)
        y_pred = tf.ones([40, 1], dtype=tf.float32)
        mse_metric.update_state(y_true, y_pred)

        # Final example-weighted loss is (10 * 0.5^2 + 40 * 0.0) / 50 = 0.05.
        self.assertEqual(mse_metric.result(), 0.05)
    def test_recreate_mean_loss_from_fn(self):
        def mse_loss(y_true, y_pred):
            return tf.reduce_mean(tf.square(y_true - y_pred))

        mse_metric = keras_utils.MeanLossMetric(mse_loss)
        recreated_mse_metric = type(mse_metric).from_config(
            mse_metric.get_config())

        y_true = tf.ones([10, 1], dtype=tf.float32)
        y_pred = tf.ones([10, 1], dtype=tf.float32) * 0.5

        mse_metric.update_state(y_true, y_pred)
        recreated_mse_metric.update_state(y_true, y_pred)

        self.assertEqual(recreated_mse_metric.result(), mse_metric.result())
def build_federated_reconstruction_evaluation(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn],
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.Computation:
  """Builds a `tff.Computation` for evaluation of a `ReconstructionModel`.

  The returned computation proceeds in two stages: (1) reconstruction and (2)
  evaluation. During the reconstruction stage, local variables are reconstructed
  by freezing global variables and training using `reconstruction_optimizer_fn`.
  During the evaluation stage, the reconstructed local variables and global
  variables are evaluated using the provided `loss_fn` and `metrics_fn`.

  Usage of returned computation:
    eval_comp = build_federated_reconstruction_evaluation(...)
    metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The loss will be applied to the model's outputs during
      the evaluation stage. The final loss metric is the example-weighted mean
      loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. The metrics will be applied to the model's outputs
      during the evaluation stage. Final metric values are the example-weighted
      mean of metric values across batches (and across clients). If None, no
      metrics are applied.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables
      with the global ones frozen.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and round number (always 0 for evaluation) and producing two TF
      datasets. The first is iterated over during reconstruction, and the second
      is iterated over during evaluation. This can be used to preprocess
      datasets to e.g. iterate over them for multiple epochs or use disjoint
      data for reconstruction and evaluation. If None, split client data in half
      for each user, using one half for reconstruction and the other for
      evaluation. See `reconstruction_utils.build_dataset_split_fn` for options.

  Raises:
    ValueError: if both `loss_fn` and `metrics_fn` are None.

  Returns:
    A `tff.Computation` that accepts model parameters and federated data and
    returns example-weighted evaluation loss and metrics.
  """
  # Construct the model first just to obtain the metadata and define all the
  # types needed to define the computations that follow.
  with tf.Graph().as_default():
    model = model_fn()
    global_weights = reconstruction_utils.get_global_variables(model)
    model_weights_type = tff.framework.type_from_tensors(global_weights)
    batch_type = tff.to_type(model.input_spec)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    if not metrics:
      raise ValueError(
          'One or both of metrics_fn and loss_fn should be provided.')
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))
    # Remove unneeded variables to avoid polluting namespace.
    del model
    del global_weights
    del metrics

  if dataset_split_fn is None:
    dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

  @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
  def client_computation(incoming_model_weights, client_dataset):
    """Reconstructs and evaluates with `incoming_model_weights`."""
    client_model = model_fn()
    client_global_weights = reconstruction_utils.get_global_variables(
        client_model)
    client_local_weights = reconstruction_utils.get_local_variables(
        client_model)
    metrics = [keras_utils.MeanLossMetric(loss_fn())]
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    batch_loss_fn = loss_fn()
    reconstruction_optimizer = reconstruction_optimizer_fn()

    @tf.function
    def reconstruction_reduce_fn(num_examples_sum, batch):
      """Runs reconstruction training on local client batch."""
      with tf.GradientTape() as tape:
        output = client_model.forward_pass(batch, training=True)
        batch_loss = batch_loss_fn(
            y_true=output.labels, y_pred=output.predictions)

      gradients = tape.gradient(batch_loss, client_local_weights.trainable)
      reconstruction_optimizer.apply_gradients(
          zip(gradients, client_local_weights.trainable))

      return num_examples_sum + output.num_examples

    @tf.function
    def evaluation_reduce_fn(num_examples_sum, batch):
      """Runs evaluation on client batch without training."""
      output = client_model.forward_pass(batch, training=False)
      # Update each metric.
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)
      return num_examples_sum + output.num_examples

    @tf.function
    def tf_client_computation(incoming_model_weights, client_dataset):
      """Reconstructs and evaluates with `incoming_model_weights`."""
      # Pass in fixed 0 round number during evaluation, since global variables
      # aren't being iteratively updated as in training.
      recon_dataset, eval_dataset = dataset_split_fn(
          client_dataset, tf.constant(0, dtype=tf.int64))

      # Assign incoming global weights to `client_model` before reconstruction.
      tf.nest.map_structure(lambda v, t: v.assign(t), client_global_weights,
                            incoming_model_weights)

      recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
      eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

      eval_local_outputs = keras_utils.read_metric_variables(metrics)
      return eval_local_outputs

    return tf_client_computation(incoming_model_weights, client_dataset)

  @tff.federated_computation(
      tff.type_at_server(model_weights_type),
      tff.type_at_clients(tff.SequenceType(batch_type)))
  def server_eval(server_model_weights, federated_dataset):
    client_outputs = tff.federated_map(
        client_computation,
        [tff.federated_broadcast(server_model_weights), federated_dataset])
    return federated_output_computation(client_outputs)

  return server_eval
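
A minimal usage sketch of the evaluation computation built above. The helpers
`my_model_fn` (a no-arg function returning a `ReconstructionModel`),
`my_loss_fn` (a no-arg function returning a `tf.keras.losses.Loss`), and
`federated_data` (a list of client `tf.data.Dataset`s matching the model's
`input_spec`) are hypothetical placeholders, not defined in the code above. The
custom `split_for_eval` illustrates the `DatasetSplitFn` contract described in
the docstring: take a client dataset and a round number, and return the
(reconstruction, evaluation) dataset pair.

def split_for_eval(client_dataset, round_num):
  """Hypothetical DatasetSplitFn: even batches reconstruct, odd batches evaluate."""
  del round_num  # Unused; evaluation always passes a fixed round number of 0.
  recon_dataset = client_dataset.shard(num_shards=2, index=0)
  eval_dataset = client_dataset.shard(num_shards=2, index=1)
  return recon_dataset, eval_dataset

eval_comp = build_federated_reconstruction_evaluation(
    my_model_fn,
    loss_fn=my_loss_fn,
    metrics_fn=lambda: [tf.keras.metrics.MeanAbsoluteError()],
    dataset_split_fn=split_for_eval)

# Mirrors the docstring's usage note: pass the global model weights and the
# federated client data; the result is the aggregated evaluation metrics.
eval_metrics = eval_comp(
    reconstruction_utils.get_global_variables(my_model_fn()),
    federated_data)
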
Example #9
def build_federated_reconstruction_process(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None,
    evaluate_reconstruction: bool = False,
    jointly_train_variables: bool = False,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_factory: Optional[
        tff.aggregators.WeightedAggregationFactory] = None,
) -> tff.templates.IterativeProcess:
    """Builds the IterativeProcess for optimization using FedRecon.

  Returns a `tff.templates.IterativeProcess` for Federated Reconstruction. On
  the client, computation can be divided into two stages: (1) reconstruction of
  local variables and (2) training of global variables (possibly jointly with
  reconstructed local variables).

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. This
      method must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation;
      returning the same pre-constructed model each call will result in an
      error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      compute local model updates during reconstruction and post-reconstruction
      and evaluate the model during training. The final loss metric is the
      example-weighted mean loss across batches and across clients. Depending on
      whether `evaluate_reconstruction` is True, the loss metric may or may not
      include reconstruction batches in the loss.
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. Metrics results are computed locally as described
      by the metric, and are aggregated across clients as in
      `federated_aggregate_keras_metric`. If None, no metrics are applied.
      Depending on whether `evaluate_reconstruction` is True, metrics may or
      may not be computed on reconstruction batches as well.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for applying updates to the global model
      on the server.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` for local client training after
      reconstruction.
    reconstruction_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to reconstruct the local variables,
      with the global ones frozen, or the first stage described above.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and training round number (1-indexed) and producing two TF
      datasets. The first is iterated over during reconstruction, and the second
      is iterated over post-reconstruction. This can be used to preprocess
      datasets to e.g. iterate over them for multiple epochs or use disjoint
      data for reconstruction and post-reconstruction. If None,
      `reconstruction_utils.simple_dataset_split_fn` is used, which results in
      iterating over the original client data for both phases of training. See
      `reconstruction_utils.build_dataset_split_fn` for options.
    evaluate_reconstruction: If True, metrics (including loss) are computed on
      batches during reconstruction and post-reconstruction. If False, metrics
      are computed on batches only post-reconstruction, when global weights are
      being updated. Note that metrics are aggregated across batches as given by
      the metric (example-weighted mean for the loss). Setting this to True
      includes all local batches in metric calculations. Setting this to False
      brings the interpretation of these metrics closer to the interpretation of
      metrics in FedAvg. Note that this does not affect training at all: losses
      for individual batches are calculated and used to update variables
      regardless.
    jointly_train_variables: Whether to train local variables during the second
      stage described above. If True, global and local variables are trained
      jointly after reconstruction of local variables using the optimizer given
      by client_optimizer_fn. If False, only global variables are trained during
      the second stage with local variables frozen, similar to alternating
      minimization.
    client_weight_fn: Optional function that takes the local model's output,
      and returns a tensor that provides the weight in the federated average of
      model deltas. If not provided, the default is the total number of examples
      processed on device during post-reconstruction phase.
    aggregation_factory: An optional instance of
      `tff.aggregators.WeightedAggregationFactory` determining the method of
      aggregation to perform. If unspecified, uses a default
      `tff.aggregators.MeanFactory` which computes a stateless weighted mean
      across clients.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    with tf.Graph().as_default():
        throwaway_model_for_metadata = model_fn()

    model_weights_type = tff.framework.type_from_tensors(
        reconstruction_utils.get_global_variables(
            throwaway_model_for_metadata))

    aggregation_process = _instantiate_aggregation_process(
        aggregation_factory, model_weights_type, client_weight_fn)
    aggregator_state_type = (
        aggregation_process.initialize.type_signature.result.member)

    server_init_tff = build_server_init_fn(model_fn, server_optimizer_fn,
                                           aggregation_process)
    server_state_type = server_init_tff.type_signature.result.member

    server_update_fn = build_server_update_fn(
        model_fn,
        server_optimizer_fn,
        server_state_type,
        server_state_type.model,
        aggregator_state_type=aggregator_state_type)

    tf_dataset_type = tff.SequenceType(throwaway_model_for_metadata.input_spec)
    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.simple_dataset_split_fn
    client_update_fn = build_client_update_fn(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        tf_dataset_type=tf_dataset_type,
        model_weights_type=server_state_type.model,
        client_optimizer_fn=client_optimizer_fn,
        reconstruction_optimizer_fn=reconstruction_optimizer_fn,
        dataset_split_fn=dataset_split_fn,
        evaluate_reconstruction=evaluate_reconstruction,
        jointly_train_variables=jointly_train_variables,
        client_weight_fn=client_weight_fn)

    federated_server_state_type = tff.type_at_server(server_state_type)
    federated_dataset_type = tff.type_at_clients(tf_dataset_type)
    # Create placeholder metrics to produce a corresponding federated output
    # computation.
    metrics = []
    if metrics_fn is not None:
        metrics.extend(metrics_fn())
    metrics.append(keras_utils.MeanLossMetric(loss_fn()))
    federated_output_computation = (
        keras_utils.federated_output_computation_from_metrics(metrics))

    run_one_round_tff = build_run_one_round_fn(
        server_update_fn,
        client_update_fn,
        federated_output_computation,
        federated_server_state_type,
        federated_dataset_type,
        aggregation_process=aggregation_process,
    )

    iterative_process = tff.templates.IterativeProcess(
        initialize_fn=server_init_tff, next_fn=run_one_round_tff)

    @tff.tf_computation(server_state_type)
    def get_model_weights(server_state):
        return server_state.model

    iterative_process.get_model_weights = get_model_weights
    return iterative_process
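
A minimal sketch of driving the returned iterative process, assuming the same
hypothetical `my_model_fn` and `my_loss_fn` helpers and a list
`federated_train_data` of client `tf.data.Dataset`s. The `(state, metrics)`
unpacking assumes the usual convention for the `next` computation produced by
`build_run_one_round_fn`.

fedrecon_process = build_federated_reconstruction_process(
    my_model_fn,
    loss_fn=my_loss_fn,
    metrics_fn=lambda: [tf.keras.metrics.MeanAbsoluteError()],
    client_optimizer_fn=functools.partial(tf.keras.optimizers.SGD, 0.01),
    jointly_train_variables=True)

state = fedrecon_process.initialize()
for round_num in range(1, 11):
  # Select client datasets for this round (a fixed slice here for brevity;
  # real training would sample clients each round).
  sampled_data = federated_train_data[:4]
  state, metrics = fedrecon_process.next(state, sampled_data)
  print('round {}: {}'.format(round_num, metrics))

# `get_model_weights` was attached above and extracts the global model weights
# from the server state.
final_global_weights = fedrecon_process.get_model_weights(state)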