def test_federated_reconstruction_metrics_none_loss_decreases(
    self, model_fn):

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
      recon_epochs_max=3)

  evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=None,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.01),
      dataset_split_fn=dataset_split_fn)
  self.assertEqual(
      str(evaluate.type_signature),
      '(<server_model_weights=<trainable=<float32[1,1]>,'
      'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
      'y=float32[?,1]>*}@CLIENTS> -> <loss=float32>@SERVER)')

  result = evaluate(
      collections.OrderedDict([
          ('trainable', [[[1.0]]]),
          ('non_trainable', []),
      ]), create_client_data())
  expected_keys = ['loss']
  self.assertCountEqual(result.keys(), expected_keys)
  # Ensure loss decreases from reconstruction vs. initializing the bias to 0.
  # MSE is (y - 1 * x)^2 for each example, for a mean of
  # (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
  self.assertLess(result['loss'], 19.666666)
def test_federated_reconstruction_split_data(self, model_fn):

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  def metrics_fn():
    return [NumExamplesCounter(), NumOverCounter(5.0)]

  evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
  self.assertEqual(
      str(evaluate.type_signature),
      '(<server_model_weights=<trainable=<float32[1,1]>,'
      'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
      'y=float32[?,1]>*}@CLIENTS> -> '
      '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)')

  result = evaluate(
      collections.OrderedDict([
          ('trainable', [[[5.0]]]),
          ('non_trainable', []),
      ]), create_client_data())
  expected_keys = ['loss', 'num_examples_total', 'num_over']
  self.assertCountEqual(result.keys(), expected_keys)
  self.assertAlmostEqual(result['num_examples_total'], 2.0)
  self.assertAlmostEqual(result['num_over'], 1.0)
def test_federated_reconstruction_skip_recon(self, model_fn):

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  def metrics_fn():
    return [NumExamplesCounter(), NumOverCounter(5.0)]

  # Ensure reconstruction is skipped if `recon_dataset` is empty. This also
  # ensures `round_num` is 0 for evaluation and loss doesn't change if
  # `eval_dataset` is repeated.
  def dataset_split_fn(client_dataset, round_num):
    return client_dataset.repeat(round_num), client_dataset.repeat(2)

  evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
      dataset_split_fn=dataset_split_fn)
  self.assertEqual(
      str(evaluate.type_signature),
      '(<server_model_weights=<trainable=<float32[1,1]>,'
      'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
      'y=float32[?,1]>*}@CLIENTS> -> '
      '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)')

  result = evaluate(
      collections.OrderedDict([
          ('trainable', [[[1.0]]]),
          ('non_trainable', []),
      ]), create_client_data())
  expected_keys = ['loss', 'num_examples_total', 'num_over']
  self.assertCountEqual(result.keys(), expected_keys)
  # Now have an expectation for loss since the local bias is initialized at 0
  # and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a mean
  # of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
  self.assertAlmostEqual(result['loss'], 19.666666)
  self.assertAlmostEqual(result['num_examples_total'], 12.0)
  self.assertAlmostEqual(result['num_over'], 6.0)
def test_federated_reconstruction_recon_lr_0(self, model_fn):

  def loss_fn():
    return tf.keras.losses.MeanSquaredError()

  def metrics_fn():
    return [NumExamplesCounter(), NumOverCounter(5.0)]

  dataset_split_fn = reconstruction_utils.build_dataset_split_fn()

  evaluate = evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      # Set recon optimizer LR to 0 so reconstruction has no effect.
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.0),
      dataset_split_fn=dataset_split_fn)
  self.assertEqual(
      str(evaluate.type_signature),
      '(<server_model_weights=<trainable=<float32[1,1]>,'
      'non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,1],'
      'y=float32[?,1]>*}@CLIENTS> -> '
      '<loss=float32,num_examples_total=float32,num_over=float32>@SERVER)')

  result = evaluate(
      collections.OrderedDict([
          ('trainable', [[[1.0]]]),
          ('non_trainable', []),
      ]), create_client_data())
  expected_keys = ['loss', 'num_examples_total', 'num_over']
  self.assertCountEqual(result.keys(), expected_keys)
  # Now have an expectation for loss since the local bias is initialized at 0
  # and not reconstructed. MSE is (y - 1 * x)^2 for each example, for a mean
  # of (4^2 + 4^2 + 5^2 + 4^2 + 3^2 + 6^2) / 6 = 59/3.
  self.assertAlmostEqual(result['loss'], 19.666666)
  self.assertAlmostEqual(result['num_examples_total'], 6.0)
  self.assertAlmostEqual(result['num_over'], 3.0)
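# Illustrative sketch (not part of the tests above): a custom
# `dataset_split_fn` follows the (client_dataset, round_num) signature used in
# `test_federated_reconstruction_skip_recon` and returns a reconstruction
# dataset and an evaluation dataset. The even/odd split below is a
# hypothetical example of such a split, not the behavior of
# `reconstruction_utils.build_dataset_split_fn`.
import tensorflow as tf


def example_even_odd_split_fn(client_dataset: tf.data.Dataset, round_num):
  """Sends even-indexed examples to reconstruction, odd-indexed to eval."""
  del round_num  # Unused in this sketch.
  recon_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 0).map(lambda _, x: x)
  eval_dataset = client_dataset.enumerate().filter(
      lambda i, _: i % 2 == 1).map(lambda _, x: x)
  return recon_dataset, eval_dataset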
def evaluation_computation_builder(
    model_fn: Callable[[], reconstruction_model.ReconstructionModel],
    loss_fn: Callable[[], tf.losses.Loss],
    metrics_fn: Callable[[], List[tf.metrics.Metric]],
    dataset_split_fn_builder: Callable[
        ..., reconstruction_utils.DatasetSplitFn] = (
            reconstruction_utils.build_dataset_split_fn),
    task_name: str = 'stackoverflow_nwp',
) -> tff.Computation:
  """Creates an evaluation computation using federated reconstruction."""

  # For a `stackoverflow_nwp_finetune` task, the first dataset returned by
  # `dataset_split_fn` is used for fine-tuning global variables. For other
  # tasks, the first dataset is used for reconstructing local variables.
  dataset_split_fn = dataset_split_fn_builder(
      recon_epochs_max=1,
      recon_epochs_constant=1,
      recon_steps_max=1,
      post_recon_epochs=1,
      post_recon_steps_max=1,
      split_dataset=True)

  if task_name == 'stackoverflow_nwp_finetune':
    return federated_evaluation.build_federated_finetune_evaluation(
        model_fn=model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        finetune_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
        dataset_split_fn=dataset_split_fn)

  return evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn=model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
      dataset_split_fn=dataset_split_fn)
def evaluation_computation_builder(
    model_fn: Callable[[], reconstruction_model.ReconstructionModel],
    loss_fn: Callable[[], tf.losses.Loss],
    metrics_fn: Callable[[], List[tf.metrics.Metric]],
    dataset_split_fn_builder: Callable[
        ..., reconstruction_utils.DatasetSplitFn] = (
            reconstruction_utils.build_dataset_split_fn),
) -> tff.Computation:
  """Creates a `tff.Computation` for federated evaluation.

  For a `stackoverflow_nwp_finetune` task, the returned `tff.Computation` is
  created by `federated_evaluation.build_federated_finetune_evaluation`. For
  other tasks, the returned `tff.Computation` is given by
  `evaluation_computation.build_federated_reconstruction_evaluation`.

  Args:
    model_fn: A no-arg function that returns a `ReconstructionModel`. The
      returned model must have only global variables for a
      `stackoverflow_nwp_finetune` task. This method must *not* capture
      TensorFlow tensors or variables and use them. The model must be
      constructed entirely from scratch on each invocation; returning the same
      model on each call will result in an error.
    loss_fn: A no-arg function returning a `tf.keras.losses.Loss` to use to
      evaluate the model. The final loss metric is the example-weighted mean
      loss across batches (and across clients).
    metrics_fn: A no-arg function returning a list of
      `tf.keras.metrics.Metric`s to use to evaluate the model. The final
      metrics are the example-weighted mean metrics across batches (and across
      clients).
    dataset_split_fn_builder: `DatasetSplitFn` builder. Returns a method used
      to split the examples into a reconstruction set (which is used as a
      fine-tuning set for a `stackoverflow_nwp_finetune` task) and an
      evaluation set.

  Returns:
    A `tff.Computation` for federated evaluation.
  """

  # For a `stackoverflow_nwp_finetune` task, the first dataset returned by
  # `dataset_split_fn` is used for fine-tuning global variables. For other
  # tasks, the first dataset is used for reconstructing local variables.
  dataset_split_fn = dataset_split_fn_builder(
      recon_epochs_max=FLAGS.recon_epochs_max,
      recon_epochs_constant=FLAGS.recon_epochs_constant,
      recon_steps_max=FLAGS.recon_steps_max,
      post_recon_epochs=FLAGS.post_recon_epochs,
      post_recon_steps_max=FLAGS.post_recon_steps_max,
      # Getting meaningful evaluation metrics requires splitting the data.
      split_dataset=True)

  if FLAGS.task == 'stackoverflow_nwp_finetune':
    return federated_evaluation.build_federated_finetune_evaluation(
        model_fn=model_fn,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        finetune_optimizer_fn=functools.partial(
            finetune_optimizer_fn, FLAGS.finetune_learning_rate),
        dataset_split_fn=dataset_split_fn)

  return evaluation_computation.build_federated_reconstruction_evaluation(
      model_fn=model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      reconstruction_optimizer_fn=functools.partial(
          reconstruction_optimizer_fn, FLAGS.reconstruction_learning_rate),
      dataset_split_fn=dataset_split_fn)
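# Hedged sketch: `finetune_optimizer_fn` and `reconstruction_optimizer_fn` are
# referenced above via `functools.partial` but are not shown in this snippet.
# Since the partials bind a learning-rate flag as the first argument, they are
# presumably simple no-frills optimizer factories along these lines (an
# assumption for illustration, not confirmed by the snippet):
import tensorflow as tf


def finetune_optimizer_fn(
    learning_rate: float) -> tf.keras.optimizers.Optimizer:
  """Returns a fresh optimizer for fine-tuning global variables."""
  return tf.keras.optimizers.SGD(learning_rate=learning_rate)


def reconstruction_optimizer_fn(
    learning_rate: float) -> tf.keras.optimizers.Optimizer:
  """Returns a fresh optimizer for reconstructing local variables."""
  return tf.keras.optimizers.SGD(learning_rate=learning_rate)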