def configure_training(task_spec: training_specs.TaskSpec,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  emnist_task = 'digit_recognition'

  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)
  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)

  client_ids_fn = tff.simulation.build_uniform_sampling_fn(
      emnist_train.client_ids,
      size=task_spec.clients_per_round,
      replace=False,
      random_seed=task_spec.client_datasets_random_seed)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  training_process.get_model_weights = iterative_process.get_model_weights

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
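# Usage sketch (illustrative only, not part of the original module): shows how
# a `TaskSpec` might be built and passed to `configure_training` above. The
# `TaskSpec` field names mirror the attributes read in that function and are
# assumed rather than taken from `training_specs`; FedAvg is used as a
# stand-in `iterative_process_builder`.
def _example_emnist_cr_runner_spec() -> training_specs.RunnerSpec:
  """Builds a `RunnerSpec` for EMNIST character recognition with FedAvg."""

  def iterative_process_builder(model_fn):
    return tff.learning.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

  # Hypothetical TaskSpec construction; field names are assumptions.
  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=1,
      client_batch_size=20,
      clients_per_round=10,
      client_datasets_random_seed=1)
  # The resulting RunnerSpec exposes `iterative_process`, `client_datasets_fn`,
  # `validation_fn` and `test_fn` for use with
  # `federated_research.utils.training_loop`.
  return configure_training(task_spec, model='cnn')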
def configure_training(
    task_spec: training_specs.TaskSpec) -> training_specs.RunnerSpec:
  """Configures training for the EMNIST autoencoder task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  emnist_task = 'autoencoder'

  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)
  input_spec = train_preprocess_fn.type_signature.result.element

  model_builder = emnist_ae_models.create_autoencoder_model
  loss_builder = functools.partial(
      tf.keras.losses.MeanSquaredError,
      reduction=tf.keras.losses.Reduction.SUM)
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  if hasattr(emnist_train, 'dataset_computation'):

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=task_spec.clients_per_round,
        random_seed=task_spec.sampling_random_seed)

  training_process.get_model_weights = iterative_process.get_model_weights

  test_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)
  validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
def configure_training(
    task_spec: training_specs.TaskSpec,
    crop_size: int = 24,
    distort_train_images: bool = True) -> training_specs.RunnerSpec:
  """Configures training for the CIFAR-100 classification task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    crop_size: An optional integer representing the resulting size of input
      images after preprocessing.
    distort_train_images: A boolean indicating whether to distort training
      images during preprocessing via random crops, as opposed to simply
      resizing the image.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  crop_shape = (crop_size, crop_size, 3)

  cifar_train, _ = tff.simulation.datasets.cifar100.load_data()
  _, cifar_test = cifar100_dataset.get_centralized_datasets(
      train_batch_size=task_spec.client_batch_size, crop_shape=crop_shape)

  train_preprocess_fn = cifar100_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      crop_shape=crop_shape,
      distort_image=distort_train_images)
  input_spec = train_preprocess_fn.type_signature.result.element

  model_builder = functools.partial(
      resnet_models.create_resnet18,
      input_shape=crop_shape,
      num_classes=NUM_CLASSES)

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = cifar_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)

  client_ids_fn = training_utils.build_sample_fn(
      cifar_train.client_ids,
      size=task_spec.clients_per_round,
      replace=False,
      random_seed=task_spec.client_datasets_random_seed)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  training_process.get_model_weights = iterative_process.get_model_weights

  centralized_eval_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=cifar_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  def test_fn(state):
    return centralized_eval_fn(iterative_process.get_model_weights(state))

  def validation_fn(state, round_num):
    del round_num
    return test_fn(state)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
def configure_training(task_spec: training_specs.TaskSpec,
                       sequence_length: int = 80) -> training_specs.RunnerSpec:
  """Configures training for the Shakespeare next-character prediction task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    sequence_length: An int specifying the length of the character sequences
      used for prediction.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  shakespeare_train, _ = tff.simulation.datasets.shakespeare.load_data()
  _, shakespeare_test = shakespeare_dataset.get_centralized_datasets(
      sequence_length=sequence_length)

  train_preprocess_fn = shakespeare_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      sequence_length=sequence_length)
  input_spec = train_preprocess_fn.type_signature.result.element

  model_builder = functools.partial(
      create_shakespeare_model, sequence_length=sequence_length)
  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)
  # NOTE: The metrics builder was missing from this excerpt; a minimal
  # stand-in using sparse categorical accuracy is assumed here.
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  if hasattr(shakespeare_train, 'dataset_computation'):

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = shakespeare_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        shakespeare_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=shakespeare_train,
        clients_per_round=task_spec.clients_per_round,
        random_seed=task_spec.client_datasets_random_seed)

  training_process.get_model_weights = iterative_process.get_model_weights

  test_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=shakespeare_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)
  validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
def configure_training(
    task_spec: training_specs.TaskSpec,
    vocab_size: int = 10000,
    num_oov_buckets: int = 1,
    sequence_length: int = 20,
    max_elements_per_user: int = 1000,
    num_validation_examples: int = 10000,
    embedding_size: int = 96,
    latent_size: int = 670,
    num_layers: int = 1,
    shared_embedding: bool = False) -> training_specs.RunnerSpec:
  """Configures training for Stack Overflow next-word prediction.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    vocab_size: Integer dictating the number of most frequent words to use in
      the vocabulary.
    num_oov_buckets: The number of out-of-vocabulary buckets to use.
    sequence_length: The maximum number of words to take for each sequence.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.
    embedding_size: The dimension of the word embedding layer.
    latent_size: The dimension of the latent units in the recurrent layers.
    num_layers: The number of stacked recurrent layers to use.
    shared_embedding: Boolean indicating whether to tie input and output
      embeddings.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=vocab_size,
      num_oov_buckets=num_oov_buckets,
      embedding_size=embedding_size,
      latent_size=latent_size,
      num_layers=num_layers,
      shared_embedding=shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_word_prediction.get_special_tokens(
      vocab_size, num_oov_buckets)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
    ]

  train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

  # TODO(b/161914546): consider moving evaluation to use
  # `tff.learning.build_federated_evaluation` to get metrics over client
  # distributions, as well as the example weight means from this centralized
  # evaluation.
  _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
      vocab_size=vocab_size,
      max_sequence_length=sequence_length,
      num_validation_examples=num_validation_examples,
      num_oov_buckets=num_oov_buckets)

  train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
      vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
      num_oov_buckets=num_oov_buckets,
      client_batch_size=task_spec.client_batch_size,
      client_epochs_per_round=task_spec.client_epochs_per_round,
      max_sequence_length=sequence_length,
      max_elements_per_client=max_elements_per_user)
  input_spec = train_dataset_preprocess_comp.type_signature.result.element

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  @tff.tf_computation(tf.string)
  def train_dataset_computation(client_id):
    client_train_data = train_clientdata.dataset_computation(client_id)
    return train_dataset_preprocess_comp(client_train_data)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      train_dataset_computation, iterative_process)

  client_ids_fn = training_utils.build_sample_fn(
      train_clientdata.client_ids,
      size=task_spec.clients_per_round,
      replace=False,
      random_seed=task_spec.client_datasets_random_seed)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  training_process.get_model_weights = iterative_process.get_model_weights

  centralized_validation_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  def validation_fn(server_state, round_num):
    del round_num
    return centralized_validation_fn(
        iterative_process.get_model_weights(server_state))

  centralized_test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  def test_fn(server_state):
    return centralized_test_fn(
        iterative_process.get_model_weights(server_state))

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
def configure_training(
    task_spec: training_specs.TaskSpec,
    vocab_tokens_size: int = 10000,
    vocab_tags_size: int = 500,
    max_elements_per_user: int = 1000,
    num_validation_examples: int = 10000) -> training_specs.RunnerSpec:
  """Configures training for the Stack Overflow tag prediction task.

  This tag prediction is performed via multi-class one-versus-rest logistic
  regression. This method will load and pre-process datasets and construct a
  model used for the task. It then uses `iterative_process_builder` to create
  an iterative process compatible with
  `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  stackoverflow_train, _, _ = tff.simulation.datasets.stackoverflow.load_data()

  _, stackoverflow_validation, stackoverflow_test = stackoverflow_tag_prediction.get_centralized_datasets(
      train_batch_size=task_spec.client_batch_size,
      word_vocab_size=vocab_tokens_size,
      tag_vocab_size=vocab_tags_size,
      num_validation_examples=num_validation_examples)

  word_vocab = stackoverflow_tag_prediction.create_word_vocab(vocab_tokens_size)
  tag_vocab = stackoverflow_tag_prediction.create_tag_vocab(vocab_tags_size)

  train_preprocess_fn = stackoverflow_tag_prediction.create_preprocess_fn(
      word_vocab=word_vocab,
      tag_vocab=tag_vocab,
      client_batch_size=task_spec.client_batch_size,
      client_epochs_per_round=task_spec.client_epochs_per_round,
      max_elements_per_client=max_elements_per_user)
  input_spec = train_preprocess_fn.type_signature.result.element

  model_builder = functools.partial(
      stackoverflow_lr_models.create_logistic_model,
      vocab_tokens_size=vocab_tokens_size,
      vocab_tags_size=vocab_tags_size)

  loss_builder = functools.partial(
      tf.keras.losses.BinaryCrossentropy,
      from_logits=False,
      reduction=tf.keras.losses.Reduction.SUM)
  # NOTE: The metrics builder was missing from this excerpt; precision and
  # recall-at-5 are assumed here as reasonable metrics for one-versus-rest
  # tag prediction.
  metrics_builder = lambda: [
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
  ]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = stackoverflow_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)

  client_ids_fn = training_utils.build_sample_fn(
      stackoverflow_train.client_ids,
      size=task_spec.clients_per_round,
      replace=False,
      random_seed=task_spec.client_datasets_random_seed)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  training_process.get_model_weights = iterative_process.get_model_weights

  centralized_validation_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=stackoverflow_validation,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  def validation_fn(server_state, round_num):
    del round_num
    return centralized_validation_fn(
        iterative_process.get_model_weights(server_state))

  centralized_test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  def test_fn(server_state):
    return centralized_test_fn(
        iterative_process.get_model_weights(server_state))

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
def configure_training(task_spec: training_specs.TaskSpec,
                       eval_spec: Optional[training_specs.EvalSpec] = None,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    eval_spec: An `EvalSpec` class for configuring federated evaluation. If set
      to None, centralized evaluation is used for validation and testing
      instead.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  emnist_task = 'digit_recognition'
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=False)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)
  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  clients_per_train_round = min(task_spec.clients_per_round,
                                TOTAL_NUM_TRAIN_CLIENTS)

  if hasattr(emnist_train, 'dataset_computation'):

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=clients_per_train_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=clients_per_train_round,
        random_seed=task_spec.sampling_random_seed)

  training_process.get_model_weights = iterative_process.get_model_weights

  if eval_spec:
    if eval_spec.clients_per_validation_round is None:
      clients_per_validation_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_validation_round = min(
          eval_spec.clients_per_validation_round, TOTAL_NUM_TEST_CLIENTS)

    if eval_spec.clients_per_test_round is None:
      clients_per_test_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_test_round = min(eval_spec.clients_per_test_round,
                                   TOTAL_NUM_TEST_CLIENTS)

    test_preprocess_fn = emnist_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=eval_spec.client_batch_size,
        shuffle_buffer_size=1,
        emnist_task=emnist_task)
    emnist_test = emnist_test.preprocess(test_preprocess_fn)

    def eval_metrics_builder():
      return [
          tf.keras.metrics.SparseCategoricalCrossentropy(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    federated_eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder=model_builder, metrics_builder=eval_metrics_builder)

    validation_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_validation_round,
        random_seed=eval_spec.sampling_random_seed)
    test_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_test_round,
        random_seed=eval_spec.sampling_random_seed)

    def validation_fn(model_weights, round_num):
      validation_clients = validation_client_sampling_fn(round_num)
      return federated_eval_fn(model_weights, validation_clients)

    def test_fn(model_weights):
      # We fix the round number to get deterministic behavior.
      test_round_num = 0
      test_clients = test_client_sampling_fn(test_round_num)
      return federated_eval_fn(model_weights, test_clients)

  else:
    _, central_emnist_test = emnist_dataset.get_centralized_datasets(
        only_digits=False, emnist_task=emnist_task)

    test_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=central_emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
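# Usage sketch (illustrative only, not part of the original module): passing an
# `EvalSpec` to the function above switches validation and testing from
# centralized evaluation to sampled federated evaluation over test clients. The
# `EvalSpec` field names mirror the attributes read in that function and are
# assumed rather than taken from `training_specs`.
def _example_emnist_cr_with_federated_eval(
    task_spec: training_specs.TaskSpec) -> training_specs.RunnerSpec:
  """Builds a `RunnerSpec` that validates and tests on sampled test clients."""
  # Hypothetical EvalSpec construction; field names are assumptions.
  eval_spec = training_specs.EvalSpec(
      clients_per_validation_round=100,
      clients_per_test_round=None,  # None falls back to all test clients.
      client_batch_size=100,
      sampling_random_seed=1)
  return configure_training(task_spec, eval_spec=eval_spec, model='2nn')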