def test_build_dataset_split_fn_recon_max_steps(self):
    """Checks that `recon_steps_max` caps the reconstruction dataset length."""
    # Range(6) batched by 2 gives 3 batches per epoch.
    client_dataset = tf.data.Dataset.range(6).batch(2)
    one_epoch = [[0, 1], [2, 3], [4, 5]]

    # Each case is (recon_steps_max, expected reconstruction batches). The
    # second cap is larger than the 6 batches produced by two epochs, so
    # adding more steps than the number of actual steps has no effect.
    cases = (
        (4, one_epoch + [[0, 1]]),
        (7, one_epoch + one_epoch),
    )
    for steps_cap, expected_recon in cases:
        split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=2, recon_steps_max=steps_cap)
        # Round number shouldn't matter.
        recon_ds, post_recon_ds = split_dataset_fn(client_dataset, 3)

        recon_batches = list(recon_ds.as_numpy_iterator())
        post_recon_batches = list(post_recon_ds.as_numpy_iterator())

        self.assertAllEqual(recon_batches, expected_recon)
        self.assertAllEqual(post_recon_batches, one_epoch)
def test_build_dataset_split_fn_recon_epochs_variable(self):
    """Checks recon epochs follow the round number when not constant."""
    # Range(6) batched by 2 gives 3 batches per epoch.
    client_dataset = tf.data.Dataset.range(6).batch(2)
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=8, recon_epochs_constant=False)
    one_epoch = [[0, 1], [2, 3], [4, 5]]

    # Round 1 is expected to yield one recon epoch and round 2 two epochs;
    # the post-reconstruction dataset stays at a single epoch.
    for round_index in (1, 2):
        round_num = tf.constant(round_index, dtype=tf.int64)
        recon_ds, post_recon_ds = split_dataset_fn(client_dataset, round_num)

        recon_batches = list(recon_ds.as_numpy_iterator())
        post_recon_batches = list(post_recon_ds.as_numpy_iterator())

        self.assertAllEqual(recon_batches, one_epoch * round_index)
        self.assertAllEqual(post_recon_batches, one_epoch)
def test_build_dataset_split_fn_split_dataset_one_batch(self):
    """Ensures splitting a client with only a single batch doesn't fail.

    With one batch, the batch is expected in the reconstruction dataset and
    the post-reconstruction dataset is expected to be empty.
    """
    # 1 batch. Batch size can be larger than number of examples.
    client_dataset = tf.data.Dataset.range(1).batch(4)
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # The same split is expected for both rounds: round number doesn't matter.
    for round_index in (1, 2):
        round_num = tf.constant(round_index, dtype=tf.int64)
        recon_dataset, post_recon_dataset = split_dataset_fn(
            client_dataset, round_num)

        recon_list = list(recon_dataset.as_numpy_iterator())
        post_recon_list = list(post_recon_dataset.as_numpy_iterator())

        self.assertAllEqual(recon_list, [[0]])
        self.assertAllEqual(post_recon_list, [])
def test_build_dataset_split_fn_split_dataset_even_batches(self):
    """Checks how an even number of batches is divided when splitting."""
    # Range(8) batched by 2 gives 4 batches.
    client_dataset = tf.data.Dataset.range(8).batch(2)
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # The same split is expected for both rounds: round number doesn't matter.
    for round_index in (1, 2):
        round_num = tf.constant(round_index, dtype=tf.int64)
        recon_ds, post_recon_ds = split_dataset_fn(client_dataset, round_num)

        recon_batches = list(recon_ds.as_numpy_iterator())
        post_recon_batches = list(post_recon_ds.as_numpy_iterator())

        # Batches at positions 0 and 2 are expected for reconstruction,
        # positions 1 and 3 for post-reconstruction.
        self.assertAllEqual(recon_batches, [[0, 1], [4, 5]])
        self.assertAllEqual(post_recon_batches, [[2, 3], [6, 7]])
def test_build_dataset_split_fn_split_dataset_zero_batches(self):
    """Ensures clients without any data don't fail."""
    # Range(0) yields no elements, hence 0 batches.
    client_dataset = tf.data.Dataset.range(0).batch(2)
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        split_dataset=True)

    # Both datasets are expected to be empty for every round: round number
    # doesn't matter.
    for round_index in (1, 2):
        round_num = tf.constant(round_index, dtype=tf.int64)
        recon_ds, post_recon_ds = split_dataset_fn(client_dataset, round_num)

        recon_batches = list(recon_ds.as_numpy_iterator())
        post_recon_batches = list(post_recon_ds.as_numpy_iterator())

        self.assertAllEqual(recon_batches, [])
        self.assertAllEqual(post_recon_batches, [])
def test_federated_reconstruction_evaluation_process_no_recon(
        self, model_fn):
    """Runs the evaluation process with zero reconstruction epochs."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    def recon_optimizer_fn():
        return tf.keras.optimizers.SGD(0.1)

    # No reconstruction epochs; the evaluation data is iterated twice.
    split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=0, post_recon_epochs=2)
    evaluator = evaluation_computation.build_federated_reconstruction_evaluation_process(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=split_fn)

    expected_initialize_signature = (
        '( -> <model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER)')
    self.assertEqual(
        str(evaluator.initialize.type_signature),
        expected_initialize_signature)

    expected_next_signature = (
        '(<state=<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
        'data={<x=float32[?,1],y=float32[?,1]>*}@CLIENTS> -> '
        '<<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
        '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER>)'
    )
    self.assertEqual(
        str(evaluator.next.type_signature), expected_next_signature)

    metric_names = ['loss', 'num_examples_total', 'num_over']

    # First pass: start from whatever state `initialize` produces.
    state = evaluator.initialize()
    state, metrics = evaluator.next(state, create_client_data())
    self.assertCountEqual(metrics.keys(), metric_names)
    self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
    self.assertAlmostEqual(metrics['num_over'], 6.0)

    # Second pass: without reconstruction and with an initialized model, we
    # can expect an exact value for loss.
    state = reconstruction_utils.ServerState(
        model=collections.OrderedDict(trainable=[[[1.0]]], non_trainable=[]),
        optimizer_state=(),
        round_num=tf.constant(0, dtype=tf.int64),
        aggregator_state=(),
    )
    state, metrics = evaluator.next(state, create_client_data())
    self.assertCountEqual(metrics.keys(), metric_names)
    # MSE is (y - 1 * x)^2 for each example, for a mean of
    # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertAlmostEqual(metrics['loss'], 19.666666)
    self.assertAlmostEqual(metrics['num_examples_total'], 12.0)
    self.assertAlmostEqual(metrics['num_over'], 6.0)
def test_build_dataset_split_fn_post_recon_multiple_epochs_max_steps(self):
    """Checks `post_recon_steps_max` caps a multi-epoch post-recon dataset."""
    # Range(6) batched by 2 gives 3 batches per epoch.
    client_dataset = tf.data.Dataset.range(6).batch(2)
    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        post_recon_epochs=2, post_recon_steps_max=4)
    one_epoch = [[0, 1], [2, 3], [4, 5]]

    # The same result is expected for both rounds: round number doesn't
    # matter.
    for round_index in (1, 2):
        round_num = tf.constant(round_index, dtype=tf.int64)
        recon_ds, post_recon_ds = split_dataset_fn(client_dataset, round_num)

        recon_batches = list(recon_ds.as_numpy_iterator())
        post_recon_batches = list(post_recon_ds.as_numpy_iterator())

        self.assertAllEqual(recon_batches, one_epoch)
        # Two epochs would be 6 batches; only the first 4 are expected.
        self.assertAllEqual(post_recon_batches, one_epoch + [[0, 1]])
def test_federated_reconstruction_metrics_none_loss_decreases(
        self, model_fn):
    """Checks evaluation with `metrics_fn=None` reports only a lowered loss."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def recon_optimizer_fn():
        return tf.keras.optimizers.SGD(0.01)

    split_fn = reconstruction_utils.build_dataset_split_fn(recon_epochs_max=3)
    evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=None,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> <loss=float32>@SERVER)')
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[1.0]]], non_trainable=[])
    result = evaluate(server_weights, create_client_data())

    # With no metrics function, loss is the only reported value.
    self.assertCountEqual(result.keys(), ['loss'])
    # Ensure loss decreases from reconstruction vs. initializing the bias to
    # 0. MSE is (y - 1 * x)^2 for each example, for a mean of
    # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertLess(result['loss'], 19.666666)
def test_federated_reconstruction_no_split_data(self, model_fn):
    """Evaluates using a dataset split function built with default options."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    def recon_optimizer_fn():
        return tf.keras.optimizers.SGD(0.1)

    split_fn = reconstruction_utils.build_dataset_split_fn()
    evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> '
        '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
    )
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[5.0]]], non_trainable=[])
    result = evaluate(server_weights, create_client_data())

    self.assertCountEqual(
        result.keys(), ['loss', 'num_examples_total', 'num_over'])
    self.assertAlmostEqual(result['num_examples_total'], 6.0)
    self.assertAlmostEqual(result['num_over'], 3.0)
def test_custom_model_eval_reconstruction_multiple_epochs(self):
    """Trains two rounds with round-dependent recon epochs; checks metrics."""
    client_data = create_emnist_client_data()
    train_data = [client_data(), client_data()]

    def loss_fn():
        return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
        return [
            NumExamplesCounter(),
            NumBatchesCounter(),
            tf.keras.metrics.SparseCategoricalAccuracy(),
        ]

    split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=3,
        recon_epochs_constant=False,
        post_recon_epochs=4,
        post_recon_steps_max=3)
    client_sgd_fn = functools.partial(tf.keras.optimizers.SGD, 0.001)
    recon_sgd_fn = functools.partial(tf.keras.optimizers.SGD, 0.001)

    trainer = training_process.build_federated_reconstruction_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        client_optimizer_fn=client_sgd_fn,
        reconstruction_optimizer_fn=recon_sgd_fn,
        evaluate_reconstruction=True,
        dataset_split_fn=split_fn)

    # Run two consecutive rounds, keeping the state and output of each.
    initial_state = trainer.initialize()
    first_state, first_output = trainer.next(initial_state, train_data)
    second_state, second_output = trainer.next(first_state, train_data)

    self.assertLess(second_output['loss'], first_output['loss'])
    self.assertNotAllClose(
        first_state.model.trainable, second_state.model.trainable)

    # Expect 6 reconstruction examples, 10 training examples.
    self.assertEqual(first_output['num_examples_total'], 16.0)
    # Expect 12 reconstruction examples, 10 training examples.
    self.assertEqual(second_output['num_examples_total'], 22.0)

    # Expect 4 reconstruction batches and 6 training batches.
    self.assertEqual(first_output['num_batches_total'], 10.0)
    # Expect 8 reconstruction batches and 6 training batches.
    self.assertEqual(second_output['num_batches_total'], 14.0)
def test_build_dataset_split_fn(self):
    """Checks epoch repetition for recon and post-recon datasets."""
    # Range(6) batched by 2 gives 3 batches per epoch.
    client_dataset = tf.data.Dataset.range(6).batch(2)
    one_epoch = [[0, 1], [2, 3], [4, 5]]

    split_dataset_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=2, post_recon_epochs=1)
    # Round number shouldn't matter.
    recon_ds, post_recon_ds = split_dataset_fn(client_dataset, 3)

    recon_batches = list(recon_ds.as_numpy_iterator())
    post_recon_batches = list(post_recon_ds.as_numpy_iterator())

    # Two recon epochs repeat the data twice; post-recon sees it once.
    self.assertAllEqual(recon_batches, one_epoch + one_epoch)
    self.assertAllEqual(post_recon_batches, one_epoch)
def test_custom_model_eval_reconstruction_split_multiple_epochs(self):
    """Trains two rounds on split client data over multiple epochs."""
    client_data = create_emnist_client_data()
    # 3 batches per user, each with one example. Since data will be split for
    # each user, each user will have 2 unique recon examples, and 1 unique
    # post-recon example (even-indices are allocated to recon during
    # splitting).
    train_data = [client_data(batch_size=1), client_data(batch_size=1)]

    def loss_fn():
        return tf.keras.losses.SparseCategoricalCrossentropy()

    def metrics_fn():
        return [
            NumExamplesCounter(),
            NumBatchesCounter(),
            tf.keras.metrics.SparseCategoricalAccuracy(),
        ]

    split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=3, split_dataset=True, post_recon_epochs=5)
    client_sgd_fn = functools.partial(tf.keras.optimizers.SGD, 0.001)

    trainer = training_process.build_federated_reconstruction_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        client_optimizer_fn=client_sgd_fn,
        evaluate_reconstruction=True,
        dataset_split_fn=split_fn)

    # Run two consecutive rounds, keeping the state and output of each.
    initial_state = trainer.initialize()
    first_state, first_output = trainer.next(initial_state, train_data)
    second_state, second_output = trainer.next(first_state, train_data)

    self.assertLess(second_output['loss'], first_output['loss'])
    self.assertNotAllClose(
        first_state.model.trainable, second_state.model.trainable)

    # Expect 12 reconstruction examples, 10 training examples.
    self.assertEqual(first_output['num_examples_total'], 22.0)
    self.assertEqual(second_output['num_examples_total'], 22.0)

    # Expect 12 reconstruction batches and 10 training batches.
    self.assertEqual(first_output['num_batches_total'], 22.0)
    self.assertEqual(second_output['num_batches_total'], 22.0)
def test_personal_matrix_factorization_trains_reconstruction_model(self):
    """Runs one training round on a personal matrix factorization model."""
    columns = [
        self.train_users.flatten().tolist(),
        self.train_items.flatten().tolist(),
        self.train_preferences.flatten().tolist(),
    ]
    # Transpose the per-field columns into one row per example.
    rows = list(zip(*columns))

    def batch_map_fn(example_batch):
        features = tf.cast(example_batch[:, 0:1], tf.int64)
        labels = example_batch[:, 1:2]
        return collections.OrderedDict(x=features, y=labels)

    client_dataset = tf.data.Dataset.from_tensor_slices(rows)
    client_dataset = client_dataset.batch(1).map(batch_map_fn).repeat(5)
    # Two clients sharing the same dataset object.
    federated_datasets = [client_dataset, client_dataset]

    num_users = 1
    num_items = 8
    num_latent_factors = 10
    keras_model_fn = functools.partial(
        models.get_matrix_factorization_model,
        num_users,
        num_items,
        num_latent_factors,
        personal_model=True,
        add_biases=False,
        l2_regularization=0.0)
    tff_model_fn = models.build_reconstruction_model(keras_model_fn)

    def client_optimizer_fn():
        return tf.keras.optimizers.SGD(learning_rate=1e-2)

    def recon_optimizer_fn():
        return tf.keras.optimizers.SGD(learning_rate=1e-3)

    # Also test `models.get_loss_fn` and `models.get_metrics_fn`.
    trainer = training_process.build_federated_reconstruction_process(
        tff_model_fn,
        loss_fn=models.get_loss_fn(),
        metrics_fn=models.get_metrics_fn(),
        client_optimizer_fn=client_optimizer_fn,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=reconstruction_utils.build_dataset_split_fn(
            recon_epochs_max=10))

    # No assertions on the result: the round only has to run.
    initial_state = trainer.initialize()
    trainer.next(initial_state, federated_datasets)
def test_federated_reconstruction_evaluation_process(self, model_fn):
    """Runs the evaluation process on split data with a post-recon step cap."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    def recon_optimizer_fn():
        return tf.keras.optimizers.SGD(0.1)

    split_fn = reconstruction_utils.build_dataset_split_fn(
        recon_epochs_max=2,
        post_recon_epochs=10,
        post_recon_steps_max=7,
        split_dataset=True)
    evaluator = evaluation_computation.build_federated_reconstruction_evaluation_process(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=split_fn)

    expected_initialize_signature = (
        '( -> <model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER)')
    self.assertEqual(
        str(evaluator.initialize.type_signature),
        expected_initialize_signature)

    expected_next_signature = (
        '(<state=<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
        'data={<x=float32[?,1],y=float32[?,1]>*}@CLIENTS> -> '
        '<<model=<trainable=<float32[1,1]>,non_trainable=<>>,'
        'optimizer_state=<>,round_num=int64,aggregator_state=<>>@SERVER,'
        '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER>)'
    )
    self.assertEqual(
        str(evaluator.next.type_signature), expected_next_signature)

    state = evaluator.initialize()
    state, metrics = evaluator.next(state, create_client_data())

    self.assertCountEqual(
        metrics.keys(), ['loss', 'num_examples_total', 'num_over'])
    self.assertAlmostEqual(metrics['num_examples_total'], 14.0)
    self.assertAlmostEqual(metrics['num_over'], 7.0)
def test_federated_reconstruction_recon_lr_0(self, model_fn):
    """Evaluates with a zero reconstruction learning rate for an exact loss."""

    def loss_fn():
        return tf.keras.losses.MeanSquaredError()

    def metrics_fn():
        return [NumExamplesCounter(), NumOverCounter(5.0)]

    def recon_optimizer_fn():
        # Set recon optimizer LR to 0 so reconstruction has no effect.
        return tf.keras.optimizers.SGD(0.0)

    split_fn = reconstruction_utils.build_dataset_split_fn()
    evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
        model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        reconstruction_optimizer_fn=recon_optimizer_fn,
        dataset_split_fn=split_fn)

    expected_signature = (
        '(<server_model_weights=<trainable=<float32[1,1]>,'
        'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
        'y=float32[?,1]>*}@CLIENTS> -> '
        '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)'
    )
    self.assertEqual(str(evaluate.type_signature), expected_signature)

    server_weights = collections.OrderedDict(
        trainable=[[[1.0]]], non_trainable=[])
    result = evaluate(server_weights, create_client_data())

    self.assertCountEqual(
        result.keys(), ['loss', 'num_examples_total', 'num_over'])
    # Now have an expectation for loss since the local bias is initialized at
    # 0 and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a
    # mean of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
    self.assertAlmostEqual(result['loss'], 19.666666)
    self.assertAlmostEqual(result['num_examples_total'], 6.0)
    self.assertAlmostEqual(result['num_over'], 3.0)
def build_federated_reconstruction_evaluation(
    model_fn: ModelFn,
    *,  # Callers pass below args by name.
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn],
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(
        tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[reconstruction_utils.DatasetSplitFn] = None
) -> tff.Computation:
    """Builds a `tff.Computation` for evaluation of a `ReconstructionModel`.

    The returned computation proceeds in two stages: (1) reconstruction and
    (2) evaluation. During the reconstruction stage, local variables are
    reconstructed by freezing global variables and training using
    reconstruction_optimizer_fn. During the evaluation stage, the
    reconstructed local variables and global variables are evaluated using
    the provided loss_fn and metrics_fn.

    Usage of returned computation:
      eval_comp = build_federated_reconstruction_evaluation(...)
      metrics = eval_comp(reconstruction_utils.get_global_variables(model),
                          federated_data)

    Args:
      model_fn: A no-arg function that returns a `ReconstructionModel`. This
        method must *not* capture Tensorflow tensors or variables and use
        them. Must be constructed entirely from scratch on each invocation,
        returning the same pre-constructed model each call will result in an
        error.
      loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
        evaluate the model. Required: it is called unconditionally, and a
        mean-loss metric built from it is always part of the output. The
        loss will be applied to the model's outputs during the evaluation
        stage. The final loss metric is the example-weighted mean loss
        across batches (and across clients).
      metrics_fn: A no-arg function returning a list of
        `tf.keras.metrics.Metric`s to evaluate the model. The metrics will be
        applied to the model's outputs during the evaluation stage. Final
        metric values are the example-weighted mean of metric values across
        batches (and across clients). If None, no metrics are applied beyond
        the loss.
      reconstruction_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` used to reconstruct the local
        variables with the global ones frozen.
      dataset_split_fn: A `reconstruction_utils.DatasetSplitFn` taking in a
        client dataset and round number (always 0 for evaluation) and
        producing two TF datasets. The first is iterated over during
        reconstruction, and the second is iterated over during evaluation.
        This can be used to preprocess datasets to e.g. iterate over them
        for multiple epochs or use disjoint data for reconstruction and
        evaluation. If None, split client data in half for each user, using
        one half for reconstruction and the other for evaluation. See
        `reconstruction_utils.build_dataset_split_fn` for options.

    Returns:
      A `tff.Computation` that accepts model parameters and federated data
      and returns example-weighted evaluation loss and metrics.
    """
    # Construct the model first just to obtain the metadata and define all
    # the types needed to define the computations that follow.
    with tf.Graph().as_default():
        model = model_fn()
        global_weights = reconstruction_utils.get_global_variables(model)
        model_weights_type = tff.framework.type_from_tensors(global_weights)
        batch_type = tff.to_type(model.input_spec)
        # The loss metric is always present, so this list is never empty.
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        federated_output_computation = (
            keras_utils.federated_output_computation_from_metrics(metrics))
        # Remove unneeded variables to avoid polluting namespace.
        del model
        del global_weights
        del metrics

    if dataset_split_fn is None:
        dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
            split_dataset=True)

    @tff.tf_computation(model_weights_type, tff.SequenceType(batch_type))
    def client_computation(incoming_model_weights, client_dataset):
        """Reconstructs and evaluates with `incoming_model_weights`."""
        client_model = model_fn()
        client_global_weights = reconstruction_utils.get_global_variables(
            client_model)
        client_local_weights = reconstruction_utils.get_local_variables(
            client_model)
        # Fresh metric objects are created for this client computation.
        metrics = [keras_utils.MeanLossMetric(loss_fn())]
        if metrics_fn is not None:
            metrics.extend(metrics_fn())
        batch_loss_fn = loss_fn()
        reconstruction_optimizer = reconstruction_optimizer_fn()

        @tf.function
        def reconstruction_reduce_fn(num_examples_sum, batch):
            """Runs reconstruction training on local client batch."""
            with tf.GradientTape() as tape:
                output = client_model.forward_pass(batch, training=True)
                batch_loss = batch_loss_fn(
                    y_true=output.labels, y_pred=output.predictions)

            # Only the local trainable variables receive gradient updates.
            gradients = tape.gradient(
                batch_loss, client_local_weights.trainable)
            reconstruction_optimizer.apply_gradients(
                zip(gradients, client_local_weights.trainable))

            return num_examples_sum + output.num_examples

        @tf.function
        def evaluation_reduce_fn(num_examples_sum, batch):
            """Runs evaluation on client batch without training."""
            output = client_model.forward_pass(batch, training=False)
            # Update each metric.
            for metric in metrics:
                metric.update_state(
                    y_true=output.labels, y_pred=output.predictions)
            return num_examples_sum + output.num_examples

        @tf.function
        def tf_client_computation(incoming_model_weights, client_dataset):
            """Reconstructs and evaluates with `incoming_model_weights`."""
            # Pass in fixed 0 round number during evaluation, since global
            # variables aren't being iteratively updated as in training.
            recon_dataset, eval_dataset = dataset_split_fn(
                client_dataset, tf.constant(0, dtype=tf.int64))

            # Assign incoming global weights to `client_model` before
            # reconstruction.
            tf.nest.map_structure(lambda v, t: v.assign(t),
                                  client_global_weights,
                                  incoming_model_weights)

            # The reduce results (example counts) are not used; the calls
            # are made for their variable and metric updates.
            recon_dataset.reduce(tf.constant(0), reconstruction_reduce_fn)
            eval_dataset.reduce(tf.constant(0), evaluation_reduce_fn)

            eval_local_outputs = keras_utils.read_metric_variables(metrics)
            return eval_local_outputs

        return tf_client_computation(incoming_model_weights, client_dataset)

    @tff.federated_computation(
        tff.type_at_server(model_weights_type),
        tff.type_at_clients(tff.SequenceType(batch_type)))
    def server_eval(server_model_weights, federated_dataset):
        client_outputs = tff.federated_map(
            client_computation,
            [tff.federated_broadcast(server_model_weights),
             federated_dataset])
        return federated_output_computation(client_outputs)

    return server_eval