def test_quantile_aggregation_for_num_examples(self):
  """Checks num_examples quantiles when clients hold 1..5 examples each."""
  quantiles = [0, 0.25, 0.5, 0.75, 1.0]

  def make_client_dataset(num_points):
    # All feature/label values are zero; only the dataset *size* matters here.
    values = [[0.0]] * num_points
    return tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(x=values, y=values)).batch(1)

  # Client i holds i + 1 examples, giving per-client counts of 1 through 5.
  client_data = [make_client_dataset(i + 1) for i in [0, 1, 2, 3, 4]]

  eval_fn = training_utils.build_federated_evaluate_fn(
      model_builder,
      lambda: [tf.keras.metrics.MeanSquaredError()],
      quantiles=quantiles)
  model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
  eval_metrics = eval_fn(model_weights, client_data)
  logging.info('Eval metrics: %s', eval_metrics)

  actual_quantiles = eval_metrics['num_examples']['quantiles']
  expected_quantiles = collections.OrderedDict(
      zip(quantiles, [1.0, 2.0, 3.0, 4.0, 5.0]))
  self.assertAllClose(actual_quantiles, expected_quantiles)
def test_eval_metrics_for_unbalanced_client_data(self):
  """Checks metric keys and weighted values when clients hold unequal data."""
  client_data = self._create_client_data(balanced=False)
  eval_fn = training_utils.build_federated_evaluate_fn(
      model_builder, lambda: [tf.keras.metrics.MeanSquaredError()])
  weights = tff.learning.ModelWeights.from_model(tff_model_fn())
  eval_metrics = eval_fn(weights, client_data)
  logging.info('Eval metrics: %s', eval_metrics)

  self.assertCountEqual(eval_metrics.keys(),
                        ['mean_squared_error', 'num_examples'])
  # Both metrics expose the same three aggregation modes.
  for metric_name in ('num_examples', 'mean_squared_error'):
    self.assertCountEqual(
        eval_metrics[metric_name].keys(),
        ['example_weighted', 'uniform_weighted', 'quantiles'])

  num_examples = eval_metrics['num_examples']
  self.assertNear(num_examples['uniform_weighted'], 2.5, err=1e-6)
  self.assertNear(num_examples['example_weighted'], 2.6, err=1e-6)

  # Unbalanced data makes the two weightings diverge.
  mse = eval_metrics['mean_squared_error']
  self.assertNear(mse['uniform_weighted'], 1.5, err=1e-6)
  self.assertNear(mse['example_weighted'], 1.8, err=1e-6)
def test_quantile_aggregation_for_mse(self):
  """Tests quantile aggregation of per-client mean squared error.

  Each client holds a single example whose value equals its client id, so
  the per-client MSE values (and therefore their quantiles over clients
  0..4) are expected to be the squares of the client ids.
  """
  client_ids = [0, 1, 2, 3, 4]
  quantiles = [0, 0.25, 0.5, 0.75, 1.0]

  def create_single_value_ds(client_id):
    client_value = [[float(client_id)]]
    return tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(x=client_value, y=client_value)).batch(1)

  client_data = tff.simulation.client_data.ConcreteClientData(
      client_ids=client_ids,
      create_tf_dataset_for_client_fn=create_single_value_ds)
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]
  eval_fn = training_utils.build_federated_evaluate_fn(
      client_data,
      model_builder,
      metrics_builder,
      clients_per_round=5,
      quantiles=quantiles)
  model_weights = tff.learning.ModelWeights.from_model(tff_model_fn())
  eval_metrics = eval_fn(model_weights, round_num=1)
  logging.info('Eval metrics: %s', eval_metrics)

  mse_quantiles = eval_metrics['mean_squared_error']['quantiles']
  expected_quantiles = collections.OrderedDict(
      zip(quantiles, [0.0, 1.0, 4.0, 9.0, 16.0]))
  # Use assertAllClose (as the num_examples quantile test does) rather than
  # assertEqual: the MSE quantiles are computed in float arithmetic, and
  # exact equality on computed floats is brittle.
  self.assertAllClose(mse_quantiles, expected_quantiles)
def test_eval_metrics_for_balanced_client_data(self):
  """Checks metric keys and values when sampled clients hold equal data."""
  client_data = self._create_client_data(balanced=True)
  eval_fn = training_utils.build_federated_evaluate_fn(
      client_data,
      model_builder,
      lambda: [tf.keras.metrics.MeanSquaredError()],
      clients_per_round=2)
  weights = tff.learning.ModelWeights.from_model(tff_model_fn())
  eval_metrics = eval_fn(weights, round_num=1)
  logging.info('Eval metrics: %s', eval_metrics)

  self.assertCountEqual(eval_metrics.keys(),
                        ['mean_squared_error', 'num_examples'])

  # Sum-based metrics: 6 examples total across 2 clients, 3.0 per client.
  num_examples = eval_metrics['num_examples']
  self.assertCountEqual(num_examples.keys(),
                        ['summed', 'uniform_weighted', 'quantiles'])
  self.assertEqual(num_examples['summed'], 6)
  self.assertEqual(num_examples['uniform_weighted'], 3.0)

  # Mean-based metrics: with balanced data, uniform and example weighting
  # must agree.
  mse = eval_metrics['mean_squared_error']
  self.assertCountEqual(mse.keys(),
                        ['example_weighted', 'uniform_weighted', 'quantiles'])
  for weighting in ('uniform_weighted', 'example_weighted'):
    self.assertNear(mse[weighting], 0.875, err=1e-6)
def _evaluation_fn():
  """Builds an evaluation function over the federated test data.

  Returns:
    A callable `(model, round_num) -> metrics` that evaluates `model` on the
    fixed federated dataset, ignoring `round_num`.
  """
  fed_eval_fn = training_utils.build_federated_evaluate_fn(
      model_builder=_keras_model_builder,
      metrics_builder=lambda: [tf.keras.metrics.SparseCategoricalAccuracy()])
  fed_data = _federated_data()

  def eval_fn(model, round_num):
    del round_num  # Unused; evaluation does not depend on the round.
    return fed_eval_fn(model, fed_data)

  return eval_fn
def configure_training(task_spec: training_specs.TaskSpec,
                       eval_spec: Optional[training_specs.EvalSpec] = None,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    eval_spec: An `EvalSpec` class for configuring federated evaluation. If set
      to None, centralized evaluation is used for validation and testing
      instead.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.

  Raises:
    ValueError: If `model` is not one of the supported model flags.
  """
  # NOTE(review): the task name is 'digit_recognition' even though
  # only_digits=False (full 62-class character data) — presumably this name
  # selects the classification (vs. autoencoder) preprocessing; confirm
  # against `emnist_dataset.create_preprocess_fn`.
  emnist_task = 'digit_recognition'
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=False)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)

  # The element spec of the preprocessed data defines the model's input spec.
  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    # Builds a fresh Keras model and wraps it for TFF; called lazily by TFF
    # so each invocation constructs new (unshared) variables.
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  # Never sample more clients per round than exist in the training split.
  clients_per_train_round = min(task_spec.clients_per_round,
                                TOTAL_NUM_TRAIN_CLIENTS)

  if hasattr(emnist_train, 'dataset_computation'):
    # Fast path: the client data exposes a TFF computation mapping a client id
    # string to that client's dataset, so preprocessing can be fused into the
    # iterative process and clients are sampled by id.

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=clients_per_train_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    # Fallback: compose only the preprocessing with the process and sample
    # materialized client datasets directly.
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=clients_per_train_round,
        random_seed=task_spec.sampling_random_seed)

  # Re-attach the weight accessor so callers can extract model weights from
  # the composed process.
  training_process.get_model_weights = iterative_process.get_model_weights

  if eval_spec:
    # Federated evaluation: sample evaluation clients each round, capping at
    # the number of available test clients. `None` means "use all clients".
    if eval_spec.clients_per_validation_round is None:
      clients_per_validation_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_validation_round = min(eval_spec.clients_per_validation_round,
                                         TOTAL_NUM_TEST_CLIENTS)

    if eval_spec.clients_per_test_round is None:
      clients_per_test_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_test_round = min(eval_spec.clients_per_test_round,
                                   TOTAL_NUM_TEST_CLIENTS)

    # Evaluation uses a single pass (num_epochs=1) and no shuffling
    # (shuffle_buffer_size=1) for deterministic metrics.
    test_preprocess_fn = emnist_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=eval_spec.client_batch_size,
        shuffle_buffer_size=1,
        emnist_task=emnist_task)
    emnist_test = emnist_test.preprocess(test_preprocess_fn)

    def eval_metrics_builder():
      return [
          tf.keras.metrics.SparseCategoricalCrossentropy(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    federated_eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder=model_builder, metrics_builder=eval_metrics_builder)

    validation_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_validation_round,
        random_seed=eval_spec.sampling_random_seed)
    test_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_test_round,
        random_seed=eval_spec.sampling_random_seed)

    def validation_fn(model_weights, round_num):
      # Validation clients vary with the round number.
      validation_clients = validation_client_sampling_fn(round_num)
      return federated_eval_fn(model_weights, validation_clients)

    def test_fn(model_weights):
      # We fix the round number to get deterministic behavior
      test_round_num = 0
      test_clients = test_client_sampling_fn(test_round_num)
      return federated_eval_fn(model_weights, test_clients)

  else:
    # Centralized evaluation: test on the pooled (non-federated) test set and
    # reuse the same function for validation, ignoring the round number.
    _, central_emnist_test = emnist_dataset.get_centralized_datasets(
        only_digits=False, emnist_task=emnist_task)

    test_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=central_emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)