def test_federated_reconstruction_metrics_none_loss_decreases(
            self, model_fn):
        """With `metrics_fn=None` only `loss` is reported, and it improves.

        Reconstruction should beat the loss obtained when the bias is left at
        its zero initialization (59/3 for the data from `create_client_data`).
        """

        def mse_loss():
            return tf.keras.losses.MeanSquaredError()

        split_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=3)

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=mse_loss,
            metrics_fn=None,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.01),
            dataset_split_fn=split_fn)

        expected_type_signature = (
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> <loss=float32>@SERVER)')
        self.assertEqual(str(evaluate.type_signature), expected_type_signature)

        server_weights = collections.OrderedDict([
            ('trainable', [[[1.0]]]),
            ('non_trainable', []),
        ])
        result = evaluate(server_weights, create_client_data())

        self.assertCountEqual(result.keys(), ['loss'])
        # Baseline without reconstruction (bias stays 0): the MSE is
        # (y - 1 * x)^2 per example, averaging
        # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        # Reconstruction must come in below that baseline.
        self.assertLess(result['loss'], 19.666666)

    def test_federated_reconstruction_split_data(self, model_fn):
        """Evaluates with the builder's default dataset split and two metrics.

        No `dataset_split_fn` is passed, so whatever default the builder uses
        decides which examples are evaluated. The count assertions pin that
        behavior.
        """

        def mse_loss():
            return tf.keras.losses.MeanSquaredError()

        def counter_metrics():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=mse_loss,
            metrics_fn=counter_metrics,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))

        expected_type_signature = (
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> '
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
        )
        self.assertEqual(str(evaluate.type_signature), expected_type_signature)

        server_weights = collections.OrderedDict([
            ('trainable', [[[5.0]]]),
            ('non_trainable', []),
        ])
        result = evaluate(server_weights, create_client_data())

        self.assertCountEqual(
            result.keys(), ['loss', 'num_examples_total', 'num_over'])
        self.assertAlmostEqual(result['num_examples_total'], 2.0)
        self.assertAlmostEqual(result['num_over'], 1.0)

    def test_federated_reconstruction_skip_recon(self, model_fn):
        """Reconstruction is skipped when the reconstruction dataset is empty."""

        def mse_loss():
            return tf.keras.losses.MeanSquaredError()

        def counter_metrics():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        # The reconstruction half is `repeat(round_num)`. It is empty only if
        # `round_num` is 0, so this also checks that evaluation runs with
        # `round_num == 0`. The evaluation half repeats the data twice: the
        # counts should double while the mean loss stays unchanged.
        def empty_recon_split_fn(client_dataset, round_num):
            recon_dataset = client_dataset.repeat(round_num)
            eval_dataset = client_dataset.repeat(2)
            return recon_dataset, eval_dataset

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=mse_loss,
            metrics_fn=counter_metrics,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            dataset_split_fn=empty_recon_split_fn)

        expected_type_signature = (
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> '
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
        )
        self.assertEqual(str(evaluate.type_signature), expected_type_signature)

        server_weights = collections.OrderedDict([
            ('trainable', [[[1.0]]]),
            ('non_trainable', []),
        ])
        result = evaluate(server_weights, create_client_data())

        self.assertCountEqual(
            result.keys(), ['loss', 'num_examples_total', 'num_over'])
        # With no reconstruction the local bias keeps its 0 initialization, so
        # the loss is known exactly: the MSE is (y - 1 * x)^2 per example,
        # averaging (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertAlmostEqual(result['loss'], 19.666666)
        self.assertAlmostEqual(result['num_examples_total'], 12.0)
        self.assertAlmostEqual(result['num_over'], 6.0)

    def test_federated_reconstruction_recon_lr_0(self, model_fn):
        """A zero reconstruction learning rate leaves the local bias untouched."""

        def mse_loss():
            return tf.keras.losses.MeanSquaredError()

        def counter_metrics():
            return [NumExamplesCounter(), NumOverCounter(5.0)]

        split_fn = reconstruction_utils.build_dataset_split_fn()

        # SGD with learning rate 0 makes reconstruction a no-op.
        def frozen_optimizer():
            return tf.keras.optimizers.SGD(0.0)

        evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
            model_fn,
            loss_fn=mse_loss,
            metrics_fn=counter_metrics,
            reconstruction_optimizer_fn=frozen_optimizer,
            dataset_split_fn=split_fn)

        expected_type_signature = (
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> '
            '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
        )
        self.assertEqual(str(evaluate.type_signature), expected_type_signature)

        server_weights = collections.OrderedDict([
            ('trainable', [[[1.0]]]),
            ('non_trainable', []),
        ])
        result = evaluate(server_weights, create_client_data())

        self.assertCountEqual(
            result.keys(), ['loss', 'num_examples_total', 'num_over'])
        # The bias stays at its 0 initialization, so the MSE is
        # (y - 1 * x)^2 per example, averaging
        # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
        self.assertAlmostEqual(result['loss'], 19.666666)
        self.assertAlmostEqual(result['num_examples_total'], 6.0)
        self.assertAlmostEqual(result['num_over'], 3.0)


def evaluation_computation_builder(
    model_fn: Callable[[], reconstruction_model.ReconstructionModel],
    loss_fn: Callable[[], tf.losses.Loss],
    metrics_fn: Callable[[], List[tf.metrics.Metric]],
    dataset_split_fn_builder: Callable[
        ..., reconstruction_utils.DatasetSplitFn] = (
            reconstruction_utils.build_dataset_split_fn),
    task_name: str = 'stackoverflow_nwp',
    *,
    reconstruction_learning_rate: float = 1.0,
    finetune_learning_rate: float = 1.0,
) -> tff.Computation:
  """Creates a federated evaluation computation for `task_name`.

  For the `stackoverflow_nwp_finetune` task the computation comes from
  `federated_evaluation.build_federated_finetune_evaluation`. For every other
  task it comes from
  `evaluation_computation.build_federated_reconstruction_evaluation`. In both
  cases the client data is split with a `DatasetSplitFn` produced by
  `dataset_split_fn_builder`, called with every epoch/step limit set to 1 and
  `split_dataset=True`.

  Args:
    model_fn: No-arg function returning a `ReconstructionModel`.
    loss_fn: No-arg function returning the loss used for evaluation.
    metrics_fn: No-arg function returning the list of metrics to compute.
    dataset_split_fn_builder: Builder returning the `DatasetSplitFn` that
      splits each client's data into a reconstruction (or fine-tuning) set and
      an evaluation set.
    task_name: Task identifier. `'stackoverflow_nwp_finetune'` selects
      fine-tuning evaluation; any other value selects reconstruction
      evaluation.
    reconstruction_learning_rate: Learning rate of the SGD optimizer used to
      reconstruct local variables (non-finetune tasks). Defaults to 1.0.
    finetune_learning_rate: Learning rate of the SGD optimizer used to
      fine-tune global variables (`stackoverflow_nwp_finetune` only). Defaults
      to 1.0.

  Returns:
    A `tff.Computation` performing federated evaluation.
  """

  # For a `stackoverflow_nwp_finetune` task, the first dataset returned by
  # `dataset_split_fn` is used for fine-tuning global variables. For other
  # tasks, the first dataset is used for reconstructing local variables.
  dataset_split_fn = dataset_split_fn_builder(
      recon_epochs_max=1,
      recon_epochs_constant=1,
      recon_steps_max=1,
      post_recon_epochs=1,
      post_recon_steps_max=1,
      split_dataset=True)

  if task_name == 'stackoverflow_nwp_finetune':
    return federated_evaluation.build_federated_finetune_evaluation(
        model_fn=model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        finetune_optimizer_fn=lambda: tf.keras.optimizers.SGD(
            finetune_learning_rate),
        dataset_split_fn=dataset_split_fn)

  return evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn=model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(
          reconstruction_learning_rate),
      dataset_split_fn=dataset_split_fn)
# Example #6
# 0
    def evaluation_computation_builder(
        model_fn: Callable[[], reconstruction_model.ReconstructionModel],
        loss_fn: Callable[[], tf.losses.Loss],
        metrics_fn: Callable[[], List[tf.metrics.Metric]],
        dataset_split_fn_builder: Callable[
            ..., reconstruction_utils.DatasetSplitFn] = (
                reconstruction_utils.build_dataset_split_fn),
    ) -> tff.Computation:
        """Builds the federated evaluation `tff.Computation` for `FLAGS.task`.

        When `FLAGS.task` is `stackoverflow_nwp_finetune`, the computation comes
        from `federated_evaluation.build_federated_finetune_evaluation`. Any
        other task uses
        `evaluation_computation.build_federated_reconstruction_evaluation`.
        Split sizes and learning rates are read from `FLAGS`.

        Args:
          model_fn: No-arg function returning a `ReconstructionModel`. For a
            `stackoverflow_nwp_finetune` task the model must contain only
            global variables. It must not capture TensorFlow tensors or
            variables, and must build a fresh model on every call; returning
            the same model twice results in an error.
          loss_fn: No-arg function returning the `tf.keras.losses.Loss` used for
            evaluation. The reported loss is the example-weighted mean across
            batches and clients.
          metrics_fn: No-arg function returning a list of
            `tf.keras.metrics.Metric`s. Each reported metric is the
            example-weighted mean across batches and clients.
          dataset_split_fn_builder: Builder for the `DatasetSplitFn` that splits
            client examples into a reconstruction set (the fine-tuning set for
            `stackoverflow_nwp_finetune`) and an evaluation set.

        Returns:
          A `tff.Computation` performing federated evaluation.
        """
        # The first dataset from the split fn fine-tunes global variables for
        # `stackoverflow_nwp_finetune`, and reconstructs local variables for
        # every other task.
        split_fn = dataset_split_fn_builder(
            recon_epochs_max=FLAGS.recon_epochs_max,
            recon_epochs_constant=FLAGS.recon_epochs_constant,
            recon_steps_max=FLAGS.recon_steps_max,
            post_recon_epochs=FLAGS.post_recon_epochs,
            post_recon_steps_max=FLAGS.post_recon_steps_max,
            # Meaningful evaluation metrics require held-out data, so always
            # split.
            split_dataset=True)

        if FLAGS.task != 'stackoverflow_nwp_finetune':
            recon_optimizer_fn = functools.partial(
                reconstruction_optimizer_fn,
                FLAGS.reconstruction_learning_rate)
            return evaluation_computation.build_federated_reconstruction_evaluation(
                model_fn=model_fn,
                loss_fn=loss_fn,
                metrics_fn=metrics_fn,
                reconstruction_optimizer_fn=recon_optimizer_fn,
                dataset_split_fn=split_fn)

        tuning_optimizer_fn = functools.partial(
            finetune_optimizer_fn, FLAGS.finetune_learning_rate)
        return federated_evaluation.build_federated_finetune_evaluation(
            model_fn=model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            finetune_optimizer_fn=tuning_optimizer_fn,
            dataset_split_fn=split_fn)