def test_evaluation_construction_calls_model_fn(self):
    """Checks that building the evaluation invokes `model_fn` only twice."""
    # `model_fn` can potentially be expensive (loading weights, processing,
    # etc.), so count how many times the evaluation builder invokes it.
    tracked_model_fn = mock.Mock(side_effect=keras_linear_model_fn)
    loss_fn = lambda: tf.keras.losses.MeanSquaredError()

    evaluation_computation.build_federated_evaluation(
        model_fn=tracked_model_fn, loss_fn=loss_fn)

    # TODO(b/186451541): Reduce the number of calls to model_fn.
    self.assertEqual(tracked_model_fn.call_count, 2)
def test_federated_reconstruction_metrics_none_loss_decreases(self, model_fn):
    """Verifies reconstruction lowers loss relative to a zero-initialized bias."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    split_fn = reconstruction_utils.build_dataset_split_fn(recon_epochs=3)
    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=loss_fn,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.01),
        dataset_split_fn=split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32>>@SERVER)')
    self.assertEqual(str(eval_comp.type_signature), expected_signature)

    result = eval_comp(
        collections.OrderedDict(trainable=[[[1.0]]], non_trainable=[]),
        create_client_data())
    metrics = result['eval']
    # Reconstruction should beat a bias initialized to 0. With a zero bias,
    # MSE is (y - 1 * x)^2 for each example, for a mean of
    # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertLess(metrics['loss'], 19.666666)
def test_federated_reconstruction_split_data(self, model_fn):
    """Checks metric values when client data is split for reconstruction."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')
    self.assertEqual(str(eval_comp.type_signature), expected_signature)

    result = eval_comp(
        collections.OrderedDict(trainable=[[[5.0]]], non_trainable=[]),
        create_client_data())
    metrics = result['eval']
    self.assertAlmostEqual(metrics['num_examples_total'], 2.0)
    self.assertAlmostEqual(metrics['num_over'], 1.0)
def test_evaluation_computation_custom_stateful_broadcaster_fails(
        self, model_fn):
    """Asserts that a broadcast process carrying state is rejected."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

    model_weights_type = type_conversions.type_from_tensors(
        reconstruction_utils.get_global_variables(model_fn()))

    def make_stateful_broadcaster(
            model_weights_type) -> measured_process_lib.MeasuredProcess:
        """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""

        # Non-empty server state (a float) makes this process stateful.
        @federated_computation.federated_computation()
        def server_init():
            return intrinsics.federated_value(2.0, placements.SERVER)

        @federated_computation.federated_computation(
            computation_types.FederatedType(tf.float32, placements.SERVER),
            computation_types.FederatedType(model_weights_type,
                                            placements.SERVER),
        )
        def broadcast_fn(state, value):
            measurements = intrinsics.federated_value(3.0, placements.SERVER)
            return measured_process_lib.MeasuredProcessOutput(
                state=state,
                result=intrinsics.federated_broadcast(value),
                measurements=measurements)

        return measured_process_lib.MeasuredProcess(
            initialize_fn=server_init, next_fn=broadcast_fn)

    with self.assertRaisesRegex(TypeError, 'must be stateless'):
        evaluation_computation.build_federated_evaluation(
            model_fn,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
            broadcast_process=make_stateful_broadcaster(
                model_weights_type=model_weights_type))
def test_evaluation_computation_custom_stateless_broadcaster(
        self, model_fn):
    """Checks that a stateless custom broadcast process is accepted and used."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [counters.NumExamplesCounter(), NumOverCounter(5.0)]

    model_weights_type = type_conversions.type_from_tensors(
        reconstruction_utils.get_global_variables(model_fn()))

    def make_stateless_broadcaster(
            model_weights_type) -> measured_process_lib.MeasuredProcess:
        """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""

        # Empty server state keeps this process stateless.
        @federated_computation.federated_computation()
        def server_init():
            return intrinsics.federated_value((), placements.SERVER)

        @federated_computation.federated_computation(
            computation_types.FederatedType((), placements.SERVER),
            computation_types.FederatedType(model_weights_type,
                                            placements.SERVER),
        )
        def broadcast_fn(state, value):
            measurements = intrinsics.federated_value(3.0, placements.SERVER)
            return measured_process_lib.MeasuredProcessOutput(
                state=state,
                result=intrinsics.federated_broadcast(value),
                measurements=measurements)

        return measured_process_lib.MeasuredProcess(
            initialize_fn=server_init, next_fn=broadcast_fn)

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        broadcast_process=make_stateless_broadcaster(
            model_weights_type=model_weights_type))

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=float32,eval='
        '<loss=float32,num_examples=int64,num_over=float32>>@SERVER)')
    self.assertEqual(str(eval_comp.type_signature), expected_signature)

    result = eval_comp(
        collections.OrderedDict(trainable=[[[1.0]]], non_trainable=[]),
        create_client_data())
    # The custom broadcaster's measurement value should surface in the result.
    self.assertEqual(result['broadcast'], 3.0)
def test_federated_reconstruction_skip_recon(self, model_fn):
    """Checks metrics when the reconstruction phase gets an empty dataset."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    # No reconstruction passes; evaluation sees the data twice.
    def dataset_split_fn(client_dataset):
        return client_dataset.repeat(0), client_dataset.repeat(2)

    eval_comp = evaluation_computation.build_federated_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
        dataset_split_fn=dataset_split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <broadcast=<>,eval='
        '<loss=float32,num_examples_total=float32,num_over=float32>>@SERVER)')
    self.assertEqual(str(eval_comp.type_signature), expected_signature)

    result = eval_comp(
        collections.OrderedDict(trainable=[[[1.0]]], non_trainable=[]),
        create_client_data())
    metrics = result['eval']
    # Now have an expectation for loss since the local bias is initialized at
    # 0 and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a
    # mean of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertAlmostEqual(metrics['loss'], 19.666666)
    self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
    self.assertAlmostEqual(metrics['num_over'], 6.0)