Example #1
def configure_training(task_spec: training_specs.TaskSpec,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  emnist_task = 'digit_recognition'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)
  client_ids_fn = tff.simulation.build_uniform_sampling_fn(
      emnist_train.client_ids,
      size=task_spec.clients_per_round,
      replace=False,
      random_seed=task_spec.client_datasets_random_seed)
  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  training_process.get_model_weights = iterative_process.get_model_weights

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
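(`EMNIST_MODELS` above is a module-level constant listing the accepted model names; it is not shown in this excerpt.) The `RunnerSpec` returned here bundles everything a generic training loop needs. Below is a minimal sketch of driving it, assuming `TaskSpec` accepts the fields read in this example as constructor arguments and using plain FedAvg as the `iterative_process_builder`; both are assumptions, not the project's actual defaults.

import tensorflow as tf
import tensorflow_federated as tff

def fed_avg_builder(model_fn):
  # Assumed choice of iterative process: vanilla FedAvg with SGD on both sides.
  return tff.learning.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# Hypothetical TaskSpec construction; the field names mirror the attributes
# read by configure_training above.
task_spec = training_specs.TaskSpec(
    iterative_process_builder=fed_avg_builder,
    client_epochs_per_round=1,
    client_batch_size=20,
    clients_per_round=10,
    client_datasets_random_seed=42)
runner_spec = configure_training(task_spec, model='cnn')

state = runner_spec.iterative_process.initialize()
for round_num in range(1, 11):
  # client_datasets_fn yields a list of sampled client IDs; the composed
  # process maps each ID to a preprocessed dataset inside the TFF computation.
  sampled_clients = runner_spec.client_datasets_fn(round_num)
  state, train_metrics = runner_spec.iterative_process.next(state, sampled_clients)
print(runner_spec.test_fn(state))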
Example #2
def configure_training(
    task_spec: training_specs.TaskSpec) -> training_specs.RunnerSpec:
  """Configures training for the EMNIST autoencoder task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """

  emnist_task = 'autoencoder'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  model_builder = emnist_ae_models.create_autoencoder_model
  loss_builder = functools.partial(
      tf.keras.losses.MeanSquaredError, reduction=tf.keras.losses.Reduction.SUM)
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  if hasattr(emnist_train, 'dataset_computation'):

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=task_spec.clients_per_round,
        random_seed=task_spec.sampling_random_seed)

  training_process.get_model_weights = iterative_process.get_model_weights

  test_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
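The `hasattr(emnist_train, 'dataset_computation')` check decides what the sampling function yields and what the composed process consumes each round. A rough sketch of the two call patterns follows; the exact return type of `training_utils.build_client_datasets_fn` is assumed here, not shown in this snippet.

state = training_process.initialize()
sampled = client_sampling_fn(0)
# Branch with dataset_computation: `sampled` is a list of raw client ID
# strings, and preprocessing runs inside the composed TFF computation.
# Fallback branch (assumed semantics): `sampled` is a list of unpreprocessed
# per-client tf.data.Datasets, and train_preprocess_fn is applied by the
# composed computation instead.
state, metrics = training_process.next(state, sampled)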
Example #3
def configure_training(
        task_spec: training_specs.TaskSpec,
        crop_size: int = 24,
        distort_train_images: bool = True) -> training_specs.RunnerSpec:
    """Configures training for the CIFAR-100 classification task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    crop_size: An optional integer representing the resulting size of input
      images after preprocessing.
    distort_train_images: A boolean indicating whether to distort training
      images during preprocessing via random crops, as opposed to simply
      resizing the image.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
    crop_shape = (crop_size, crop_size, 3)

    cifar_train, _ = tff.simulation.datasets.cifar100.load_data()
    _, cifar_test = cifar100_dataset.get_centralized_datasets(
        train_batch_size=task_spec.client_batch_size, crop_shape=crop_shape)

    train_preprocess_fn = cifar100_dataset.create_preprocess_fn(
        num_epochs=task_spec.client_epochs_per_round,
        batch_size=task_spec.client_batch_size,
        crop_shape=crop_shape,
        distort_image=distort_train_images)
    input_spec = train_preprocess_fn.type_signature.result.element

    model_builder = functools.partial(resnet_models.create_resnet18,
                                      input_shape=crop_shape,
                                      num_classes=NUM_CLASSES)

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
        client_dataset = cifar_train.dataset_computation(client_id)
        return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        cifar_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

    training_process.get_model_weights = iterative_process.get_model_weights

    centralized_eval_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=cifar_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def test_fn(state):
        return centralized_eval_fn(iterative_process.get_model_weights(state))

    def validation_fn(state, round_num):
        del round_num
        return test_fn(state)

    return training_specs.RunnerSpec(iterative_process=training_process,
                                     client_datasets_fn=client_sampling_fn,
                                     validation_fn=validation_fn,
                                     test_fn=test_fn)
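The `distort_train_images` flag chooses between randomized crops for training and a deterministic crop for evaluation. The snippet below only illustrates that idea with stock TensorFlow image ops; it is not the actual `cifar100_dataset.create_preprocess_fn` implementation, which also handles repeating over epochs and batching.

import tensorflow as tf

CROP_SHAPE = (24, 24, 3)

def preprocess_image(image, distort):
  # image: a [32, 32, 3] uint8 CIFAR image.
  image = tf.cast(image, tf.float32) / 255.0
  if distort:
    # Training path: random crop plus random flip act as data augmentation.
    image = tf.image.random_crop(image, size=CROP_SHAPE)
    image = tf.image.random_flip_left_right(image)
  else:
    # Eval path: deterministic central crop so metrics are reproducible.
    image = tf.image.resize_with_crop_or_pad(image, CROP_SHAPE[0], CROP_SHAPE[1])
  return image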
Example #4
def configure_training(task_spec: training_specs.TaskSpec,
                       sequence_length: int = 80) -> training_specs.RunnerSpec:
    """Configures training for the Shakespeare next-character prediction task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    sequence_length: An int specifying the length of the character sequences
      used for prediction.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """

    shakespeare_train, _ = tff.simulation.datasets.shakespeare.load_data()
    _, shakespeare_test = shakespeare_dataset.get_centralized_datasets(
        sequence_length=sequence_length)

    train_preprocess_fn = shakespeare_dataset.create_preprocess_fn(
        num_epochs=task_spec.client_epochs_per_round,
        batch_size=task_spec.client_batch_size,
        sequence_length=sequence_length)
    input_spec = train_preprocess_fn.type_signature.result.element

    model_builder = functools.partial(create_shakespeare_model,
                                      sequence_length=sequence_length)
    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    if hasattr(shakespeare_train, 'dataset_computation'):

        @tff.tf_computation(tf.string)
        def build_train_dataset_from_client_id(client_id):
            client_dataset = shakespeare_train.dataset_computation(client_id)
            return train_preprocess_fn(client_dataset)

        training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
            build_train_dataset_from_client_id, iterative_process)
        client_ids_fn = training_utils.build_sample_fn(
            shakespeare_train.client_ids,
            size=task_spec.clients_per_round,
            replace=False,
            random_seed=task_spec.client_datasets_random_seed)
        # We convert the output to a list (instead of an np.ndarray) so that it can
        # be used as input to the iterative process.
        client_sampling_fn = lambda x: list(client_ids_fn(x))

    else:
        training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
            train_preprocess_fn, iterative_process)
        client_sampling_fn = training_utils.build_client_datasets_fn(
            dataset=shakespeare_train,
            clients_per_round=task_spec.clients_per_round,
            random_seed=task_spec.client_datasets_random_seed)

    training_process.get_model_weights = iterative_process.get_model_weights

    test_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=shakespeare_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    validation_fn = lambda model_weights, round_num: test_fn(model_weights)

    return training_specs.RunnerSpec(iterative_process=training_process,
                                     client_datasets_fn=client_sampling_fn,
                                     validation_fn=validation_fn,
                                     test_fn=test_fn)
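Note that `metrics_builder` is used in `tff_model_fn` and in `build_centralized_evaluate_fn` but never defined in this excerpt; it is a module-level helper in the source file. A plausible stand-in mirroring the classification examples above (an assumption, not the file's actual definition):

# Hypothetical module-level definition; the real file may instead use a masked
# or flattened accuracy metric tailored to padded character sequences.
metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]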
Example #5
def configure_training(
        task_spec: training_specs.TaskSpec,
        vocab_size: int = 10000,
        num_oov_buckets: int = 1,
        sequence_length: int = 20,
        max_elements_per_user: int = 1000,
        num_validation_examples: int = 10000,
        embedding_size: int = 96,
        latent_size: int = 670,
        num_layers: int = 1,
        shared_embedding: bool = False) -> training_specs.RunnerSpec:
    """Configures training for Stack Overflow next-word prediction.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    vocab_size: Integer dictating the number of most frequent words to use in
      the vocabulary.
    num_oov_buckets: The number of out-of-vocabulary buckets to use.
    sequence_length: The maximum number of words to take for each sequence.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.
    embedding_size: The dimension of the word embedding layer.
    latent_size: The dimension of the latent units in the recurrent layers.
    num_layers: The number of stacked recurrent layers to use.
    shared_embedding: Boolean indicating whether to tie input and output
      embeddings.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size, num_oov_buckets)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    # TODO(b/161914546): consider moving evaluation to use
    # `tff.learning.build_federated_evaluation` to get metrics over client
    # distributions, as well as the example weight means from this centralized
    # evaluation.
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=vocab_size,
        max_sequence_length=sequence_length,
        num_validation_examples=num_validation_examples,
        num_oov_buckets=num_oov_buckets)

    train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
        vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
        num_oov_buckets=num_oov_buckets,
        client_batch_size=task_spec.client_batch_size,
        client_epochs_per_round=task_spec.client_epochs_per_round,
        max_sequence_length=sequence_length,
        max_elements_per_client=max_elements_per_user)

    input_spec = train_dataset_preprocess_comp.type_signature.result.element

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def train_dataset_computation(client_id):
        client_train_data = train_clientdata.dataset_computation(client_id)
        return train_dataset_preprocess_comp(client_train_data)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_dataset_computation, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        train_clientdata.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

    training_process.get_model_weights = iterative_process.get_model_weights

    centralized_validation_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def validation_fn(server_state, round_num):
        del round_num
        return centralized_validation_fn(
            iterative_process.get_model_weights(server_state))

    centralized_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def test_fn(server_state):
        return centralized_test_fn(
            iterative_process.get_model_weights(server_state))

    return training_specs.RunnerSpec(iterative_process=training_process,
                                     client_datasets_fn=client_sampling_fn,
                                     validation_fn=validation_fn,
                                     test_fn=test_fn)
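The `MaskedCategoricalAccuracy` metrics above report accuracy while excluding chosen token IDs (padding, the out-of-vocabulary buckets, and end-of-sentence) from the count. A hand-rolled sketch of that masking logic, purely to illustrate what the metric is assumed to measure:

import tensorflow as tf

def masked_accuracy(y_true, logits, masked_tokens):
  # y_true: [batch, seq] integer token ids; logits: [batch, seq, vocab_size].
  masked = tf.constant(masked_tokens, dtype=y_true.dtype)
  predictions = tf.argmax(logits, axis=-1, output_type=y_true.dtype)
  # Keep only positions whose true token is not one of the masked tokens.
  keep = tf.reduce_all(tf.not_equal(y_true[..., tf.newaxis], masked), axis=-1)
  correct = tf.logical_and(keep, tf.equal(predictions, y_true))
  kept = tf.reduce_sum(tf.cast(keep, tf.float32))
  return tf.reduce_sum(tf.cast(correct, tf.float32)) / tf.maximum(kept, 1.0)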
Example #6
def configure_training(
        task_spec: training_specs.TaskSpec,
        vocab_tokens_size: int = 10000,
        vocab_tags_size: int = 500,
        max_elements_per_user: int = 1000,
        num_validation_examples: int = 10000) -> training_specs.RunnerSpec:
    """Configures training for the Stack Overflow tag prediction task.

  This tag prediction is performed via multi-class one-versus-rest logistic
  regression. This method will load and pre-process datasets and construct a
  model used for the task. It then uses `iterative_process_builder` to create an
  iterative process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """

    stackoverflow_train, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    _, stackoverflow_validation, stackoverflow_test = stackoverflow_tag_prediction.get_centralized_datasets(
        train_batch_size=task_spec.client_batch_size,
        word_vocab_size=vocab_tokens_size,
        tag_vocab_size=vocab_tags_size,
        num_validation_examples=num_validation_examples)

    word_vocab = stackoverflow_tag_prediction.create_word_vocab(
        vocab_tokens_size)
    tag_vocab = stackoverflow_tag_prediction.create_tag_vocab(vocab_tags_size)

    train_preprocess_fn = stackoverflow_tag_prediction.create_preprocess_fn(
        word_vocab=word_vocab,
        tag_vocab=tag_vocab,
        client_batch_size=task_spec.client_batch_size,
        client_epochs_per_round=task_spec.client_epochs_per_round,
        max_elements_per_client=max_elements_per_user)
    input_spec = train_preprocess_fn.type_signature.result.element

    model_builder = functools.partial(
        stackoverflow_lr_models.create_logistic_model,
        vocab_tokens_size=vocab_tokens_size,
        vocab_tags_size=vocab_tags_size)

    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
        client_dataset = stackoverflow_train.dataset_computation(client_id)
        return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        stackoverflow_train.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

    training_process.get_model_weights = iterative_process.get_model_weights

    centralized_validation_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def validation_fn(server_state, round_num):
        del round_num
        return centralized_validation_fn(
            iterative_process.get_model_weights(server_state))

    centralized_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    def test_fn(server_state):
        return centralized_test_fn(
            iterative_process.get_model_weights(server_state))

    return training_specs.RunnerSpec(iterative_process=training_process,
                                     client_datasets_fn=client_sampling_fn,
                                     validation_fn=validation_fn,
                                     test_fn=test_fn)
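As in the Shakespeare example, `metrics_builder` is referenced here without being defined in the excerpt. Since the task is one-versus-rest logistic regression over tags, precision and top-k recall are natural metrics; the definition below is a guess at the module-level helper, not the source file's exact code.

# Hypothetical module-level definition for the tag-prediction task. The model
# outputs per-tag probabilities (from_logits=False above), so the default 0.5
# threshold applies for precision.
metrics_builder = lambda: [
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
]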
Example #7
def configure_training(task_spec: training_specs.TaskSpec,
                       eval_spec: Optional[training_specs.EvalSpec] = None,
                       model: str = 'cnn') -> training_specs.RunnerSpec:
  """Configures training for the EMNIST character recognition task.

  This method will load and pre-process datasets and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process compatible with `federated_research.utils.training_loop`.

  Args:
    task_spec: A `TaskSpec` class for creating federated training tasks.
    eval_spec: An `EvalSpec` class for configuring federated evaluation. If set
      to None, centralized evaluation is used for validation and testing
      instead.
    model: A string specifying the model used for character recognition. Can be
      one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).

  Returns:
    A `RunnerSpec` containing attributes used for running the newly created
    federated task.
  """
  emnist_task = 'digit_recognition'

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=False)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=task_spec.client_epochs_per_round,
      batch_size=task_spec.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  iterative_process = task_spec.iterative_process_builder(tff_model_fn)

  clients_per_train_round = min(task_spec.clients_per_round,
                                TOTAL_NUM_TRAIN_CLIENTS)

  if hasattr(emnist_train, 'dataset_computation'):

    @tff.tf_computation(tf.string)
    def build_train_dataset_from_client_id(client_id):
      client_dataset = emnist_train.dataset_computation(client_id)
      return train_preprocess_fn(client_dataset)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        build_train_dataset_from_client_id, iterative_process)
    client_ids_fn = training_utils.build_sample_fn(
        emnist_train.client_ids,
        size=clients_per_train_round,
        replace=False,
        random_seed=task_spec.sampling_random_seed)
    # We convert the output to a list (instead of an np.ndarray) so that it can
    # be used as input to the iterative process.
    client_sampling_fn = lambda x: list(client_ids_fn(x))

  else:
    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_preprocess_fn, iterative_process)
    client_sampling_fn = training_utils.build_client_datasets_fn(
        dataset=emnist_train,
        clients_per_round=clients_per_train_round,
        random_seed=task_spec.sampling_random_seed)

  training_process.get_model_weights = iterative_process.get_model_weights

  if eval_spec:

    if eval_spec.clients_per_validation_round is None:
      clients_per_validation_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_validation_round = min(eval_spec.clients_per_validation_round,
                                         TOTAL_NUM_TEST_CLIENTS)

    if eval_spec.clients_per_test_round is None:
      clients_per_test_round = TOTAL_NUM_TEST_CLIENTS
    else:
      clients_per_test_round = min(eval_spec.clients_per_test_round,
                                   TOTAL_NUM_TEST_CLIENTS)

    test_preprocess_fn = emnist_dataset.create_preprocess_fn(
        num_epochs=1,
        batch_size=eval_spec.client_batch_size,
        shuffle_buffer_size=1,
        emnist_task=emnist_task)
    emnist_test = emnist_test.preprocess(test_preprocess_fn)

    def eval_metrics_builder():
      return [
          tf.keras.metrics.SparseCategoricalCrossentropy(),
          tf.keras.metrics.SparseCategoricalAccuracy()
      ]

    federated_eval_fn = training_utils.build_federated_evaluate_fn(
        model_builder=model_builder, metrics_builder=eval_metrics_builder)

    validation_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_validation_round,
        random_seed=eval_spec.sampling_random_seed)
    test_client_sampling_fn = training_utils.build_client_datasets_fn(
        emnist_test,
        clients_per_test_round,
        random_seed=eval_spec.sampling_random_seed)

    def validation_fn(model_weights, round_num):
      validation_clients = validation_client_sampling_fn(round_num)
      return federated_eval_fn(model_weights, validation_clients)

    def test_fn(model_weights):
      # We fix the round number to get deterministic behavior
      test_round_num = 0
      test_clients = test_client_sampling_fn(test_round_num)
      return federated_eval_fn(model_weights, test_clients)

  else:
    _, central_emnist_test = emnist_dataset.get_centralized_datasets(
        only_digits=False, emnist_task=emnist_task)

    test_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=central_emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  return training_specs.RunnerSpec(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn)
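(`EMNIST_MODELS`, `TOTAL_NUM_TRAIN_CLIENTS`, and `TOTAL_NUM_TEST_CLIENTS` are module-level constants not shown in this excerpt.) This variant adds an optional `EvalSpec` that switches validation and testing from centralized evaluation to federated evaluation over sampled test clients. A sketch of opting in, assuming `EvalSpec` accepts the fields read above as constructor arguments and reusing a `task_spec` built as in the earlier sketch:

# Hypothetical EvalSpec construction; the field names mirror the attributes
# read by configure_training above.
eval_spec = training_specs.EvalSpec(
    clients_per_validation_round=100,
    clients_per_test_round=500,
    client_batch_size=32,
    sampling_random_seed=42)
runner_spec = configure_training(task_spec, eval_spec=eval_spec, model='cnn')

# In both branches of this example, validation_fn and test_fn take model
# weights rather than the full server state (unlike Example #1), so the caller
# extracts them with get_model_weights first.
state = runner_spec.iterative_process.initialize()
weights = runner_spec.iterative_process.get_model_weights(state)
print(runner_spec.validation_fn(weights, round_num=0))
print(runner_spec.test_fn(weights))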