Code example #1
def test_has_only_global_variables_true(self):
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    input_spec = _create_input_spec()
    model = keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=keras_model.layers,
        local_layers=[],
        input_spec=input_spec)
    self.assertTrue(reconstruction_utils.has_only_global_variables(model))
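
For contrast, a hypothetical companion test (a sketch, not taken from the source suite) covers the negative case: designating any layer as local should make `has_only_global_variables` return `False`.

def test_has_only_global_variables_false(self):
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    input_spec = _create_input_spec()
    # Mark the final layer as local, so the model is no longer fully global.
    model = keras_utils.from_keras_model(
        keras_model=keras_model,
        global_layers=keras_model.layers[:-1],
        local_layers=keras_model.layers[-1:],
        input_spec=input_spec)
    self.assertFalse(reconstruction_utils.has_only_global_variables(model))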
Code example #2
def build_federated_finetune_evaluation(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    finetune_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, learning_rate=0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.Computation:
    """Builds a computation for evaluating a fully global `ReconstructionModel`.

  The input `model_fn` must return a `ReconstructionModel` that has only global
  variables. The returned computation proceeds in two stages on every client:
  (1) fine-tuning and (2) evaluation. During the fine-tuning stage, all global
  variables are fine-tuned on the first `tf.data.Dataset` returned by
  `dataset_split_fn` using finetune_optimizer_fn. During the evaluation stage,
  the fine-tuned model is evaluated on the second `tf.data.Dataset` returned by
  `dataset_split_fn`.

  Usage of returned computation:
    eval_comp = build_federated_finetune_evaluation(...)
    metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                        federated_data)

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. The
      returned model must have only global variables. This method must *not*
      capture Tensorflow tensors or variables and use them. Must be constructed
      entirely from scratch on each invocation, returning the same model each
      call will result in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The loss will be applied to the model's outputs during
      the evaluation stage. The final loss metric is the example-weighted mean
      loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of `tf.keras.metrics.Metric`s
      to evaluate the model. The metrics will be applied to the model's outputs
      during the evaluation stage. Final metric values are the example-weighted
      mean of metric values across batches (and across clients). If None, no
      metrics are applied.
    finetune_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer` used to fine-tune the global variables. A
      learning rate of zero means no fine-tuning.
    dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a client
      dataset and round number (always 0 for evaluation) and producing two TF
      datasets. The first is iterated over during fine-tuning, and the second is
      iterated over during evaluation. If None, split client data in half for
      each user, using even-indexed entries for fine-tuning and odd-indexed
      entries for evaluation. See
      `federated_trainer_utils.build_dataset_split_fn` for options.

  Raises:
    ValueError: if `model_fn` returns a model with local variables.

  Returns:
    A `tff.Computation` that accepts model parameters and federated data and
    returns example-weighted evaluation loss and metrics.
  """
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    with tf.Graph().as_default():
        model = model_fn()
        global_weights = reconstruction_utils.get_global_variables(model)
        if not reconstruction_utils.has_only_global_variables(model):
            raise ValueError(
                '`model_fn` should return a model with only global variables.')
        model_weights_type = tff.framework.type_from_tensors(global_weights)
        batch_type = tff.to_type(model.input_spec)
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        federated_output_computation = (
            keras_utils.federated_output_computation_from_metrics(metrics))
        # Remove unneeded variables to avoid polluting namespace.
        del model
        del global_weights
        del metrics

    # With `split_dataset=True`, each client's data is split in half: per the
    # docstring above, even-indexed entries are used for fine-tuning and
    # odd-indexed entries for evaluation.
    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_computation(incoming_model_weights, client_dataset):
        """Fine-tunes and evaluates with `incoming_model_weights`."""
        client_model = model_fn()
        client_global_weights = reconstruction_utils.get_global_variables(
            client_model)
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        batch_loss_fn = loss_fn()
        finetune_optimizer = finetune_optimizer_fn()

        @tf.function
        def finetune_reduce_fn(num_examples_sum, batch):
            """Fine-tunes the model on local client batch."""
            with tf.GradientTape() as tape:
                output = client_model.forward_pass(batch, training=True)
                batch_loss = batch_loss_fn(y_true=output.labels,
                                           y_pred=output.predictions)

            gradients = tape.gradient(batch_loss,
                                      client_global_weights.trainable)
            finetune_optimizer.apply_gradients(
                zip(gradients, client_global_weights.trainable))

            return num_examples_sum + output.num_examples

        @tf.function
        def evaluation_reduce_fn(num_examples_sum, batch):
            """Runs evaluation on client batch without training."""
            output = client_model.forward_pass(batch, training=False)
            # Update each metric.
            for metric in metrics:
                metric.update_state(y_true=output.labels,
                                    y_pred=output.predictions)
            return num_examples_sum + output.num_examples

        @tf.function
        def tf_client_computation(incoming_model_weights, client_dataset):
            """Fine-tunes and evaluates with `incoming_model_weights`."""
            # Pass in fixed 0 round number during evaluation.
            finetune_dataset, eval_dataset = dataset_split_fn(
                client_dataset, tf.constant(0, dtype=tf.int64))

            # Assign incoming global weights to `client_model` before fine-tuning.
            tf.nest.map_structure(lambda v, t: v.assign(t),
                                  client_global_weights,
                                  incoming_model_weights)

            # `reduce` drives each dataset through the corresponding
            # tf.function; the accumulated example counts are unused here and
            # are discarded.
            finetune_dataset.reduce(tf.constant(0), finetune_reduce_fn)
            eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

            eval_local_outputs = keras_utils.read_metric_variables(metrics)
            return eval_local_outputs

        return tf_client_computation(incoming_model_weights, client_dataset)

    @tff.federated_computation(tff.type_at_server(model_weights_type),
                               tff.type_at_clients(
                                   tff.SequenceType(batch_type)))
    def server_eval(server_model_weights, federated_dataset):
        client_outputs = tff.federated_map(
            client_computation,
            [tff.federated_broadcast(server_model_weights), federated_dataset])
        return federated_output_computation(client_outputs)

    return server_eval
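
To make the docstring's usage pattern concrete, here is a minimal sketch; `my_model_fn`, `my_loss_fn`, `my_metrics_fn`, `model`, and `federated_data` are hypothetical placeholders for task-specific definitions.

# Hypothetical placeholders stand in for task-specific definitions.
eval_comp = build_federated_finetune_evaluation(
    my_model_fn,
    loss_fn=my_loss_fn,
    metrics_fn=my_metrics_fn,
    finetune_optimizer_fn=functools.partial(
        tf.keras.optimizers.SGD, learning_rate=0.01))
# Evaluate the current global weights on federated client data, as in the
# docstring's usage example.
metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                    federated_data)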
Code example #3
    def iterative_process_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], List[tf.keras.losses.Loss]],
        metrics_fn: Optional[Callable[[],
                                      List[tf.keras.metrics.Metric]]] = None,
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn
        ] = reconstruction_utils.build_dataset_split_fn,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    For a `stackoverflow_nwp_finetune` task, the `model_fn` must return a model
    that has only global variables, and the argument `dataset_split_fn_builder`
    is ignored. The returned iterative process is basically the same as the one
    created by the standard `tff.learning.build_federated_averaging_process`.

    For other tasks, the returned iterative process performs the federated
    reconstruction algorithm defined by
    `training_process.build_federated_reconstruction_process`.

    Args:
      model_fn: A no-arg function returning a
        `reconstruction_model.ReconstructionModel`. The returned model must have
        only global variables for a `stackoverflow_nwp_finetune` task.
      loss_fn: A no-arg function returning a list of `tf.keras.losses.Loss`.
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`.
      client_weight_fn: Optional function that takes the local model's output,
        and returns a tensor that provides the weight in the federated average
        of model deltas. If not provided, the default is the total number of
        examples processed on device. If DP is used, this argument is ignored,
        and uniform client weighting is used.
      dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method used
        to split the examples into a reconstruction, and post-reconstruction
        set. Ignored for a `stackoverflow_nwp_finetune` task.

    Raises:
      ValueError: if `model_fn` returns a model with local variables for a
        `stackoverflow_nwp_finetune` task.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        # Get aggregation factory for DP, if needed.
        aggregation_factory = None
        client_weighting = client_weight_fn
        if FLAGS.dp_noise_multiplier is not None:
            aggregation_factory = tff.learning.dp_aggregator(
                noise_multiplier=FLAGS.dp_noise_multiplier,
                clients_per_round=float(FLAGS.clients_per_round),
                zeroing=FLAGS.dp_zeroing)
            # DP is only implemented for uniform weighting.
            client_weighting = lambda _: 1.0

        if FLAGS.task == 'stackoverflow_nwp_finetune':

            if not reconstruction_utils.has_only_global_variables(model_fn()):
                raise ValueError(
                    '`model_fn` should return a model with only global '
                    'variables.')

            def fake_dataset_split_fn(
                client_dataset: tf.data.Dataset, round_num: tf.Tensor
            ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
                del round_num
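                # `repeat(0)` yields an empty reconstruction dataset, so the
                # reconstruction stage is a no-op; every epoch of client data
                # goes to ordinary (post-reconstruction) training.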
                return client_dataset.repeat(0), client_dataset.repeat(
                    FLAGS.client_epochs_per_round)

            return training_process.build_federated_reconstruction_process(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                server_optimizer_fn=lambda: server_optimizer_fn(
                    FLAGS.server_learning_rate),
                client_optimizer_fn=lambda: client_optimizer_fn(
                    FLAGS.client_learning_rate),
                dataset_split_fn=fake_dataset_split_fn,
                client_weight_fn=client_weighting,
                aggregation_factory=aggregation_factory)

        return training_process.build_federated_reconstruction_process(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=lambda: server_optimizer_fn(
                FLAGS.server_learning_rate),
            client_optimizer_fn=lambda: client_optimizer_fn(
                FLAGS.client_learning_rate),
            reconstruction_optimizer_fn=functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate),
            dataset_split_fn=dataset_split_fn_builder(
                recon_epochs_max=FLAGS.recon_epochs_max,
                recon_epochs_constant=FLAGS.recon_epochs_constant,
                recon_steps_max=FLAGS.recon_steps_max,
                post_recon_epochs=FLAGS.post_recon_epochs,
                post_recon_steps_max=FLAGS.post_recon_steps_max,
                split_dataset=FLAGS.split_dataset),
            evaluate_reconstruction=FLAGS.evaluate_reconstruction,
            jointly_train_variables=FLAGS.jointly_train_variables,
            client_weight_fn=client_weighting,
            aggregation_factory=aggregation_factory)
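
Finally, a minimal driver-loop sketch showing where the builder fits; `my_model_fn`, `my_loss_fn`, and `sample_client_datasets` are hypothetical placeholders, and `initialize`/`next` is the standard `tff.templates.IterativeProcess` interface.

# Hypothetical placeholders stand in for task-specific definitions.
iterative_process = iterative_process_builder(
    model_fn=my_model_fn, loss_fn=my_loss_fn)

state = iterative_process.initialize()
for round_num in range(10):
    # Each call to `next` runs one federated round, returning the updated
    # server state and that round's metrics.
    state, metrics = iterative_process.next(state, sample_client_datasets())
    print('round {:2d}: {}'.format(round_num, metrics))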