Example no. 1
0
  def test_evaluation_construction_calls_model_fn(self):
    """Checks how many times building the evaluation invokes `model_fn`."""
    # `model_fn` can potentially be expensive (loading weights, processing,
    # etc.), so the builder should not call it more often than necessary.
    tracked_model_fn = mock.Mock(side_effect=keras_linear_model_fn)

    evaluation_computation.build_federated_evaluation(
        model_fn=tracked_model_fn,
        loss_fn=lambda: tf.keras.losses.MeanSquaredError())
    # TODO(b/186451541): Reduce the number of calls to model_fn.
    self.assertEqual(tracked_model_fn.call_count, 2)
Example no. 2
0
  def test_federated_reconstruction_metrics_none_loss_decreases(self, model_fn):
    """Evaluation without metrics: reconstruction lowers loss below baseline."""
    split_fn = reconstruction_utils.build_dataset_split_fn(recon_epochs=3)

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.01),
        dataset_split_fn=split_fn)
    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32>>@SERVER)')
    self.assertEqual(str(eval_comp.type_signature), expected_signature)

    server_weights = collections.OrderedDict([
        ('trainable', [[[1.0]]]),
        ('non_trainable', []),
    ])
    eval_result = eval_comp(server_weights, create_client_data())['eval']

    # Ensure loss decreases from reconstruction vs. initializing the bias to 0.
    # MSE is (y - 1 * x)^2 for each example, for a mean of
    # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertLess(eval_result['loss'], 19.666666)
Example no. 3
0
  def test_federated_reconstruction_split_data(self, model_fn):
    """Evaluation with custom metrics reports the expected counter values."""

    def make_metrics():
      return [NumExamplesCounter(), NumOverCounter(5.0)]

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
        metrics_fn=make_metrics,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    self.assertEqual(
        str(eval_comp.type_signature),
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')

    server_weights = collections.OrderedDict([
        ('trainable', [[[5.0]]]),
        ('non_trainable', []),
    ])
    metrics = eval_comp(server_weights, create_client_data())['eval']

    self.assertAlmostEqual(metrics['num_examples_total'], 2.0)
    self.assertAlmostEqual(metrics['num_over'], 1.0)
Example no. 4
0
    def test_evaluation_computation_custom_stateful_broadcaster_fails(
            self, model_fn):
        """A broadcast process that carries server state must be rejected."""
        weights_type = type_conversions.type_from_tensors(
            reconstruction_utils.get_global_variables(model_fn()))

        def make_stateful_broadcaster(
                weights_type) -> measured_process_lib.MeasuredProcess:
            """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""

            @federated_computation.federated_computation()
            def init_fn():
                # Non-empty float state is what makes this process stateful.
                return intrinsics.federated_value(2.0, placements.SERVER)

            @federated_computation.federated_computation(
                computation_types.FederatedType(tf.float32, placements.SERVER),
                computation_types.FederatedType(weights_type,
                                                placements.SERVER),
            )
            def next_fn(state, value):
                measurements = intrinsics.federated_value(
                    3.0, placements.SERVER)
                return measured_process_lib.MeasuredProcessOutput(
                    state=state,
                    result=intrinsics.federated_broadcast(value),
                    measurements=measurements)

            return measured_process_lib.MeasuredProcess(
                initialize_fn=init_fn, next_fn=next_fn)

        def make_metrics():
            return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

        # The evaluation builder only accepts stateless broadcast processes.
        with self.assertRaisesRegex(TypeError, 'must be stateless'):
            evaluation_computation.build_federated_evaluation(
                model_fn,
                loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
                metrics_fn=make_metrics,
                reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(
                    0.1),
                broadcast_process=make_stateful_broadcaster(
                    weights_type=weights_type))
Example no. 5
0
    def test_evaluation_computation_custom_stateless_broadcaster(
            self, model_fn):
        """A stateless custom broadcast process is accepted and measured."""
        weights_type = type_conversions.type_from_tensors(
            reconstruction_utils.get_global_variables(model_fn()))

        def make_stateless_broadcaster(
                weights_type) -> measured_process_lib.MeasuredProcess:
            """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""

            @federated_computation.federated_computation()
            def init_fn():
                # Empty-tuple state keeps the process stateless.
                return intrinsics.federated_value((), placements.SERVER)

            @federated_computation.federated_computation(
                computation_types.FederatedType((), placements.SERVER),
                computation_types.FederatedType(weights_type,
                                                placements.SERVER),
            )
            def next_fn(state, value):
                measurements = intrinsics.federated_value(
                    3.0, placements.SERVER)
                return measured_process_lib.MeasuredProcessOutput(
                    state=state,
                    result=intrinsics.federated_broadcast(value),
                    measurements=measurements)

            return measured_process_lib.MeasuredProcess(
                initialize_fn=init_fn, next_fn=next_fn)

        def make_metrics():
            return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

        eval_comp = evaluation_computation.build_federated_evaluation(
            model_fn,
            loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
            metrics_fn=make_metrics,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            broadcast_process=make_stateless_broadcaster(
                weights_type=weights_type))
        self.assertEqual(
            str(eval_comp.type_signature),
            '(<server_model_weights=<trainable=<float32[1,1]>,'
            'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
            'y=float32[?,1]>*}@CLIENTS> -> <broadcast=float32,eval='
            '<loss=float32,num_examples=int64,num_over=float32>>@SERVER)')

        server_weights = collections.OrderedDict([
            ('trainable', [[[1.0]]]),
            ('non_trainable', []),
        ])
        output = eval_comp(server_weights, create_client_data())

        # The broadcast measurement is the constant emitted by `next_fn`.
        self.assertEqual(output['broadcast'], 3.0)
Example no. 6
0
  def test_federated_reconstruction_skip_recon(self, model_fn):
    """With an empty reconstruction split, the bias stays at initialization."""

    def split_client_data(client_dataset):
      # Empty reconstruction pass (repeat(0)); evaluation sees each example
      # twice (repeat(2)).
      return client_dataset.repeat(0), client_dataset.repeat(2)

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
        metrics_fn=lambda: [NumExamplesCounter(), NumOverCounter(5.0)],
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        dataset_split_fn=split_client_data)
    self.assertEqual(
        str(eval_comp.type_signature),
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')

    server_weights = collections.OrderedDict([
        ('trainable', [[[1.0]]]),
        ('non_trainable', []),
    ])
    eval_result = eval_comp(server_weights, create_client_data())['eval']

    # Now have an expectation for loss since the local bias is initialized at 0
    # and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a mean
    # of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3
    self.assertAlmostEqual(eval_result['loss'], 19.666666)
    self.assertAlmostEqual(eval_result['num_examples_total'], 12.0)
    self.assertAlmostEqual(eval_result['num_over'], 6.0)