def test_build_to_ids_fn_truncates(self):
     """Checks that a 3-word input with `max_seq_len=1` yields only BOS + 1 id."""
     words = ['A', 'B', 'C']
     seq_len_cap = 1
     bos_id = stackoverflow_word_prediction.get_special_tokens(len(words)).bos
     tokens_to_ids = stackoverflow_word_prediction.build_to_ids_fn(
         words, seq_len_cap)
     ids = tokens_to_ids({'tokens': 'A B C'})
     # Only the BOS id and the id of the first word are expected to survive.
     expected_ids = [bos_id, 1]
     self.assertAllEqual(self.evaluate(ids), expected_ids)
 def test_oov_token_correct(self):
     """Checks that a word outside the vocab is mapped to one of the OOV ids."""
     words = ['A', 'B', 'C']
     seq_len_cap = 5
     oov_bucket_count = 2
     tokens_to_ids = stackoverflow_word_prediction.build_to_ids_fn(
         words, seq_len_cap, num_oov_buckets=oov_bucket_count)
     specials = stackoverflow_word_prediction.get_special_tokens(
         len(words), num_oov_buckets=oov_bucket_count)
     oov_ids = specials.oov
     # 'D' is the only word that does not appear in `words`.
     ids = tokens_to_ids({'tokens': 'A B D'})
     self.assertLen(oov_ids, oov_bucket_count)
     # Presumably index 0 holds BOS, so index 3 is the slot for 'D' -- confirm
     # against `build_to_ids_fn`.
     id_for_unknown_word = self.evaluate(ids)[3]
     self.assertIn(id_for_unknown_word, oov_ids)
 def test_build_to_ids_fn_embeds_all_vocab(self):
     """Checks that every vocab word gets its own id, wrapped in BOS and EOS."""
     words = ['A', 'B', 'C']
     seq_len_cap = 5
     specials = stackoverflow_word_prediction.get_special_tokens(len(words))
     bos_id, eos_id = specials.bos, specials.eos
     tokens_to_ids = stackoverflow_word_prediction.build_to_ids_fn(
         words, seq_len_cap)
     ids = tokens_to_ids({'tokens': 'A B C'})
     expected_ids = [bos_id, 1, 2, 3, eos_id]
     self.assertAllEqual(self.evaluate(ids), expected_ids)
 def test_pad_token_correct(self):
    """Checks that the slot added by `padded_batch` equals the pad token id."""
    words = ['A', 'B', 'C']
    seq_len_cap = 5
    tokens_to_ids = stackoverflow_word_prediction.build_to_ids_fn(
        words, seq_len_cap)
    specials = stackoverflow_word_prediction.get_special_tokens(len(words))
    pad_id = specials.pad
    bos_id = specials.bos
    eos_id = specials.eos
    ids = tokens_to_ids({'tokens': 'A B C'})
    # Pad the single example out to length 6, one more than the 5 ids expected
    # from `tokens_to_ids`, so exactly one padding slot is produced.
    unbatched = tf.data.Dataset.from_tensor_slices([ids])
    padded = unbatched.padded_batch(1, padded_shapes=[6])
    first_batch = next(iter(padded))
    expected_batch = [[bos_id, 1, 2, 3, eos_id, pad_id]]
    self.assertAllEqual(self.evaluate(first_batch), expected_batch)
예제 #5
0
def _get_stackoverflow_metrics(vocab_size, num_oov_buckets):
  """Builds the masked accuracy metrics for the stackoverflow dataset.

  Args:
    vocab_size: Forwarded to `stackoverflow_dataset.get_special_tokens`.
    num_oov_buckets: Forwarded to `stackoverflow_dataset.get_special_tokens`.

  Returns:
    A list of three `keras_metrics.MaskedCategoricalAccuracy` metrics that mask
    (1) the pad token, (2) the pad and OOV tokens, and (3) the pad, EOS and OOV
    tokens, in that order.
  """
  specials = stackoverflow_dataset.get_special_tokens(
      vocab_size, num_oov_buckets)
  pad_id = specials.pad
  oov_ids = specials.oov
  eos_id = specials.eos
  # (metric name, token ids excluded from the accuracy computation).
  masks_by_name = (
      ('accuracy_with_oov', [pad_id]),
      ('accuracy_no_oov', [pad_id] + oov_ids),
      ('accuracy_no_oov_or_eos', [pad_id, eos_id] + oov_ids),
  )
  return [
      keras_metrics.MaskedCategoricalAccuracy(
          name=metric_name, masked_tokens=masked_ids)
      for metric_name, masked_ids in masks_by_name
  ]
예제 #6
0
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    vocab_size: int = 10000,
                    num_oov_buckets: int = 1,
                    d_embed: int = 96,
                    d_model: int = 512,
                    d_hidden: int = 2048,
                    num_heads: int = 8,
                    num_layers: int = 1,
                    max_position_encoding: int = 1000,
                    dropout: float = 0.1,
                    num_validation_examples: int = 10000,
                    sequence_length: int = 20,
                    experiment_name: str = 'centralized_stackoverflow',
                    root_output_dir: str = '/tmp/fedopt_guide',
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    max_batches: Optional[int] = None):
    """Trains a Transformer on the Stack Overflow next word prediction task.

    Builds the centralized train/validation/test datasets, optionally caps
    each of them to `max_batches` batches, compiles a Transformer language
    model with masked accuracy metrics and hands everything over to
    `centralized_training_loop.run`.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      num_epochs: The number of training epochs.
      batch_size: The batch size, used for train, validation, and test.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      vocab_size: Vocab size for normal tokens.
      num_oov_buckets: Number of out of vocabulary buckets.
      d_embed: Dimension of the token embeddings.
      d_model: Dimension of features of MultiHeadAttention layers.
      d_hidden: Dimension of hidden layers of the FFN.
      num_heads: Number of attention heads.
      num_layers: Number of Transformer blocks.
      max_position_encoding: Maximum number of positions for position
        embeddings.
      dropout: Dropout rate.
      num_validation_examples: The number of test examples to use for
        validation.
      sequence_length: The maximum number of words to take for each sequence.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      max_batches: If set to a positive integer, datasets are capped to at
        most that many batches. If set to None or a nonpositive integer, the
        full datasets are used.
    """
    train_split, validation_split, test_split = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size,
            sequence_length,
            train_batch_size=batch_size,
            num_validation_examples=num_validation_examples,
            num_oov_buckets=num_oov_buckets,
        ))

    # A falsy (None / 0) or nonpositive cap leaves all three datasets whole.
    if max_batches and max_batches >= 1:
        train_split, validation_split, test_split = [
            split.take(max_batches)
            for split in (train_split, validation_split, test_split)
        ]

    transformer = transformer_models.create_transformer_lm(
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        d_embed=d_embed,
        d_model=d_model,
        d_hidden=d_hidden,
        num_heads=num_heads,
        num_layers=num_layers,
        max_position_encoding=max_position_encoding,
        dropout=dropout,
        name='stackoverflow-transformer')

    specials = stackoverflow_word_prediction.get_special_tokens(
        vocab_size=vocab_size, num_oov_buckets=num_oov_buckets)
    pad_id = specials.pad
    oov_ids = specials.oov
    eos_id = specials.eos

    # (metric name, token ids excluded from the accuracy computation).
    masks_by_name = (
        ('accuracy_with_oov', [pad_id]),
        ('accuracy_no_oov', [pad_id] + oov_ids),
        ('accuracy_no_oov_or_eos', [pad_id, eos_id] + oov_ids),
    )
    accuracy_metrics = [
        keras_metrics.MaskedCategoricalAccuracy(name=metric_name,
                                                masked_tokens=masked_ids)
        for metric_name, masked_ids in masks_by_name
    ]

    transformer.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=accuracy_metrics)

    centralized_training_loop.run(
        keras_model=transformer,
        train_dataset=train_split,
        validation_dataset=validation_split,
        test_dataset=test_split,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
def run_centralized(optimizer: tf.keras.optimizers.Optimizer,
                    experiment_name: str,
                    root_output_dir: str,
                    num_epochs: int,
                    batch_size: int,
                    decay_epochs: Optional[int] = None,
                    lr_decay: Optional[float] = None,
                    hparams_dict: Optional[Mapping[str, Any]] = None,
                    vocab_size: Optional[int] = 10000,
                    num_oov_buckets: Optional[int] = 1,
                    sequence_length: Optional[int] = 20,
                    num_validation_examples: Optional[int] = 10000,
                    embedding_size: Optional[int] = 96,
                    latent_size: Optional[int] = 670,
                    num_layers: Optional[int] = 1,
                    shared_embedding: Optional[bool] = False,
                    max_batches: Optional[int] = None,
                    cache_dir: Optional[str] = None):
    """Trains an RNN on the Stack Overflow next word prediction task.

    Builds the centralized train/validation/test datasets, optionally caps
    each of them to `max_batches` batches, compiles a recurrent language model
    with masked accuracy metrics and hands everything over to
    `centralized_training_loop.run`.

    Args:
      optimizer: A `tf.keras.optimizers.Optimizer` used to perform training.
      experiment_name: The name of the experiment. Part of the output
        directory.
      root_output_dir: The top-level output directory for experiment runs. The
        `experiment_name` argument will be appended, and the directory will
        contain tensorboard logs, metrics written as CSVs, and a CSV of
        hyperparameter choices (if `hparams_dict` is used).
      num_epochs: The number of training epochs.
      batch_size: The batch size, used for train, validation, and test.
      decay_epochs: The number of epochs of training before decaying the
        learning rate. If None, no decay occurs.
      lr_decay: The amount to decay the learning rate by after `decay_epochs`
        training epochs have occurred.
      hparams_dict: A mapping with string keys representing the
        hyperparameters and their values. If not None, this is written to CSV.
      vocab_size: Integer dictating the number of most frequent words to use
        in the vocabulary.
      num_oov_buckets: The number of out-of-vocabulary buckets to use.
      sequence_length: The maximum number of words to take for each sequence.
      num_validation_examples: The number of test examples to use for
        validation.
      embedding_size: The dimension of the word embedding layer.
      latent_size: The dimension of the latent units in the recurrent layers.
      num_layers: The number of stacked recurrent layers to use.
      shared_embedding: Boolean indicating whether to tie input and output
        embeddings.
      max_batches: If set to a positive integer, datasets are capped to at
        most that many batches. If set to None or a nonpositive integer, the
        full datasets are used.
      cache_dir: Forwarded unchanged to `get_centralized_datasets`; presumably
        the directory where the downloaded dataset is cached -- confirm
        against that function.
    """
    train_split, validation_split, test_split = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size=vocab_size,
            max_sequence_length=sequence_length,
            train_batch_size=batch_size,
            num_validation_examples=num_validation_examples,
            num_oov_buckets=num_oov_buckets,
            cache_dir=cache_dir))

    # A falsy (None / 0) or nonpositive cap leaves all three datasets whole.
    if max_batches and max_batches >= 1:
        train_split, validation_split, test_split = [
            split.take(max_batches)
            for split in (train_split, validation_split, test_split)
        ]

    recurrent_model = stackoverflow_models.create_recurrent_model(
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        name='stackoverflow-lstm',
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    specials = stackoverflow_word_prediction.get_special_tokens(
        vocab_size=vocab_size, num_oov_buckets=num_oov_buckets)
    pad_id = specials.pad
    oov_ids = specials.oov
    eos_id = specials.eos

    # (metric name, token ids excluded from the accuracy computation).
    masks_by_name = (
        ('accuracy_with_oov', [pad_id]),
        ('accuracy_no_oov', [pad_id] + oov_ids),
        ('accuracy_no_oov_or_eos', [pad_id, eos_id] + oov_ids),
    )
    accuracy_metrics = [
        keras_metrics.MaskedCategoricalAccuracy(name=metric_name,
                                                masked_tokens=masked_ids)
        for metric_name, masked_ids in masks_by_name
    ]

    recurrent_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        metrics=accuracy_metrics)

    centralized_training_loop.run(
        keras_model=recurrent_model,
        train_dataset=train_split,
        validation_dataset=validation_split,
        test_dataset=test_split,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        num_epochs=num_epochs,
        hparams_dict=hparams_dict,
        decay_epochs=decay_epochs,
        lr_decay=lr_decay)
예제 #8
0
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        client_datasets_random_seed: Optional[int] = None,
        vocab_size: Optional[int] = 10000,
        num_oov_buckets: Optional[int] = 1,
        sequence_length: Optional[int] = 20,
        max_elements_per_user: Optional[int] = 1000,
        num_validation_examples: Optional[int] = 10000,
        embedding_size: Optional[int] = 96,
        latent_size: Optional[int] = 670,
        num_layers: Optional[int] = 1,
        shared_embedding: Optional[bool] = False,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_so_nwp',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        **kwargs):
    """Runs an iterative process on the Stack Overflow next word prediction task.

    This method will load and pre-process dataset and construct a model used
    for the task. It then uses `iterative_process_builder` to create an
    iterative process that it applies to the task, using
    `federated_research.utils.training_loop`.

    We assume that the iterative process has the following functional type
    signatures:

      *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
          represents the server state, `{B*}` represents the client datasets,
          and `T` represents a python `Mapping` object.

    The iterative process must also have a callable attribute
    `get_model_weights` that takes as input the state of the iterative
    process, and returns a `tff.learning.ModelWeights` object.

    Args:
      iterative_process_builder: A function that accepts a no-arg `model_fn`,
        a `client_weight_fn` and returns a `tff.templates.IterativeProcess`.
        The `model_fn` must return a `tff.learning.Model`.
      client_epochs_per_round: An integer representing the number of epochs of
        training performed per client in each training round.
      client_batch_size: An integer representing the batch size used on
        clients.
      clients_per_round: An integer representing the number of clients
        participating in each round.
      client_datasets_random_seed: An optional int used to seed which clients
        are sampled at each round. If `None`, no seed is used.
      vocab_size: Integer dictating the number of most frequent words to use
        in the vocabulary.
      num_oov_buckets: The number of out-of-vocabulary buckets to use.
      sequence_length: The maximum number of words to take for each sequence.
      max_elements_per_user: The maximum number of elements processed for each
        client's dataset.
      num_validation_examples: The number of test examples to use for
        validation.
      embedding_size: The dimension of the word embedding layer.
      latent_size: The dimension of the latent units in the recurrent layers.
      num_layers: The number of stacked recurrent layers to use.
      shared_embedding: Boolean indicating whether to tie input and output
        embeddings.
      total_rounds: The number of federated training rounds.
      experiment_name: The name of the experiment being run. This will be
        appended to the `root_output_dir` for purposes of writing outputs.
      root_output_dir: The name of the root output directory for writing
        experiment outputs.
      **kwargs: Additional arguments configuring the training loop. They are
        forwarded unchanged to `training_loop.run`.
    """

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size, num_oov_buckets)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        """Returns fresh metric objects; called once per model instance."""
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    # TODO(b/161914546): consider moving evaluation to use
    # `tff.learning.build_federated_evaluation` to get metrics over client
    # distributions, as well as the example weight means from this centralized
    # evaluation.
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=vocab_size,
        max_sequence_length=sequence_length,
        num_validation_examples=num_validation_examples,
        num_oov_buckets=num_oov_buckets)

    train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
        vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
        num_oov_buckets=num_oov_buckets,
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        max_sequence_length=sequence_length,
        max_elements_per_client=max_elements_per_user)

    # The model's input spec is the element type produced by the client
    # preprocessing computation.
    input_spec = train_dataset_preprocess_comp.type_signature.result.element

    def tff_model_fn() -> tff.learning.Model:
        """Wraps a freshly built Keras model as a `tff.learning.Model`."""
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    iterative_process = iterative_process_builder(
        tff_model_fn, client_weight_fn=client_weight_fn)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_dataset_preprocess_comp, iterative_process)

    # The composed process does not carry over `get_model_weights`, so it is
    # re-attached from the original iterative process.
    training_process.get_model_weights = iterative_process.get_model_weights

    client_datasets_fn = training_utils.build_client_datasets_fn(
        dataset=train_clientdata,
        clients_per_round=clients_per_round,
        random_seed=client_datasets_random_seed)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    # The round number is accepted for interface compatibility but unused.
    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    # Keras `Model.summary()` prints to stdout and returns None, so wrapping
    # the call in `logging.info(...)` would only log "None". Route every
    # summary line through the logger via `print_fn` instead.
    model_builder().summary(print_fn=logging.info)

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)
def configure_training(
        task_spec: training_specs.TaskSpec,
        vocab_size: int = 10000,
        num_oov_buckets: int = 1,
        sequence_length: int = 20,
        max_elements_per_user: int = 1000,
        num_validation_examples: int = 10000,
        embedding_size: int = 96,
        latent_size: int = 670,
        num_layers: int = 1,
        shared_embedding: bool = False) -> training_specs.RunnerSpec:
    """Configures training for Stack Overflow next-word prediction.

    Loads and pre-processes the datasets, constructs the recurrent model used
    for the task, and uses `task_spec.iterative_process_builder` to create an
    iterative process compatible with `federated_research.utils.training_loop`.

    Args:
      task_spec: A `TaskSpec` class for creating federated training tasks.
      vocab_size: Integer dictating the number of most frequent words to use
        in the vocabulary.
      num_oov_buckets: The number of out-of-vocabulary buckets to use.
      sequence_length: The maximum number of words to take for each sequence.
      max_elements_per_user: The maximum number of elements processed for each
        client's dataset.
      num_validation_examples: The number of test examples to use for
        validation.
      embedding_size: The dimension of the word embedding layer.
      latent_size: The dimension of the latent units in the recurrent layers.
      num_layers: The number of stacked recurrent layers to use.
      shared_embedding: Boolean indicating whether to tie input and output
        embeddings.

    Returns:
      A `RunnerSpec` containing attributes used for running the newly created
      federated task.
    """
    build_model = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    build_loss = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    specials = stackoverflow_word_prediction.get_special_tokens(
        vocab_size, num_oov_buckets)
    pad_id = specials.pad
    oov_ids = specials.oov
    eos_id = specials.eos

    def build_metrics():
        """Returns a fresh list of accuracy metrics and batch/token counters."""
        # (metric name, token ids excluded from the accuracy computation).
        # Notice BOS never appears in ground truth.
        masks_by_name = (
            ('accuracy_with_oov', [pad_id]),
            ('accuracy_no_oov', [pad_id] + oov_ids),
            ('accuracy_no_oov_or_eos', [pad_id, eos_id] + oov_ids),
        )
        metrics = [
            keras_metrics.MaskedCategoricalAccuracy(name=metric_name,
                                                    masked_tokens=masked_ids)
            for metric_name, masked_ids in masks_by_name
        ]
        metrics.append(keras_metrics.NumBatchesCounter())
        metrics.append(keras_metrics.NumTokensCounter(masked_tokens=[pad_id]))
        return metrics

    train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    # TODO(b/161914546): consider moving evaluation to use
    # `tff.learning.build_federated_evaluation` to get metrics over client
    # distributions, as well as the example weight means from this centralized
    # evaluation.
    _, validation_split, test_split = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size=vocab_size,
            max_sequence_length=sequence_length,
            num_validation_examples=num_validation_examples,
            num_oov_buckets=num_oov_buckets))

    preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
        vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
        num_oov_buckets=num_oov_buckets,
        client_batch_size=task_spec.client_batch_size,
        client_epochs_per_round=task_spec.client_epochs_per_round,
        max_sequence_length=sequence_length,
        max_elements_per_client=max_elements_per_user)

    # The model's input spec is the element type produced by preprocessing.
    model_input_spec = preprocess_comp.type_signature.result.element

    def tff_model_fn() -> tff.learning.Model:
        """Wraps a freshly built Keras model as a `tff.learning.Model`."""
        return tff.learning.from_keras_model(
            keras_model=build_model(),
            input_spec=model_input_spec,
            loss=build_loss(),
            metrics=build_metrics())

    iterative_process = task_spec.iterative_process_builder(tff_model_fn)

    @tff.tf_computation(tf.string)
    def train_dataset_computation(client_id):
        """Maps a client id to that client's preprocessed training dataset."""
        raw_client_data = train_clientdata.dataset_computation(client_id)
        return preprocess_comp(raw_client_data)

    training_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process(
            train_dataset_computation, iterative_process))
    # Re-attach `get_model_weights` from the original process onto the
    # composed one.
    training_process.get_model_weights = iterative_process.get_model_weights

    sample_client_ids = training_utils.build_sample_fn(
        train_clientdata.client_ids,
        size=task_spec.clients_per_round,
        replace=False,
        random_seed=task_spec.client_datasets_random_seed)

    def client_sampling_fn(x):
        # We convert the output to a list (instead of an np.ndarray) so that
        # it can be used as input to the iterative process.
        return list(sample_client_ids(x))

    evaluate_on_validation = training_utils.build_centralized_evaluate_fn(
        model_builder=build_model,
        eval_dataset=validation_split,
        loss_builder=build_loss,
        metrics_builder=build_metrics)

    def validation_fn(server_state, round_num):
        """Evaluates the current server model on the validation dataset."""
        del round_num
        current_weights = iterative_process.get_model_weights(server_state)
        return evaluate_on_validation(current_weights)

    evaluate_on_test = training_utils.build_centralized_evaluate_fn(
        model_builder=build_model,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_split.concatenate(test_split),
        loss_builder=build_loss,
        metrics_builder=build_metrics)

    def test_fn(server_state):
        """Evaluates the current server model on validation + test data."""
        current_weights = iterative_process.get_model_weights(server_state)
        return evaluate_on_test(current_weights)

    return training_specs.RunnerSpec(
        iterative_process=training_process,
        client_datasets_fn=client_sampling_fn,
        validation_fn=validation_fn,
        test_fn=test_fn)
예제 #10
0
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        max_elements_per_user: int,
        total_rounds: int = 3000,
        vocab_size: int = 10000,
        num_oov_buckets: int = 1,
        sequence_length: int = 20,
        num_validation_examples: int = 10000,
        dim_embed: int = 96,
        dim_model: int = 512,
        dim_hidden: int = 2048,
        num_heads: int = 8,
        num_layers: int = 1,
        max_position_encoding: int = 1000,
        dropout: float = 0.1,
        client_datasets_random_seed: Optional[int] = None,
        experiment_name: str = 'federated_stackoverflow',
        root_output_dir: str = '/tmp/fedopt_guide',
        max_val_test_batches: Optional[int] = None,
        **kwargs) -> None:
    """Configures and runs training for Stack Overflow next-word prediction.

    This method will load and pre-process dataset and construct a model used
    for the task. It then uses `iterative_process_builder` to create an
    iterative process that it applies to the task, using
    `federated_research/fedopt_guide/training_loop`.

    We assume that the iterative process has the following functional type
    signatures:

      *   `initialize`: `( -> S@SERVER)` where `S` represents the server
          state.
      *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
          represents the server state, `{B*}` represents the client datasets,
          and `T` represents a python `Mapping` object.

    The iterative process must also have a callable attribute
    `get_model_weights` that takes as input the state of the iterative
    process, and returns a `tff.learning.ModelWeights` object.

    Args:
      iterative_process_builder: A function that accepts a no-arg `model_fn`,
        a `client_weight_fn` and returns a `tff.templates.IterativeProcess`.
        The `model_fn` must return a `tff.learning.Model`.
      client_epochs_per_round: An integer representing the number of epochs of
        training performed per client in each training round.
      client_batch_size: An integer representing the batch size used on
        clients.
      clients_per_round: An integer representing the number of clients
        participating in each round.
      max_elements_per_user: The maximum number of elements processed for each
        client's dataset. This has to be a positive value or -1 (which means
        that all elements are taken for training).
      total_rounds: The number of federated training rounds.
      vocab_size: Integer dictating the number of most frequent words to use
        in the vocabulary.
      num_oov_buckets: The number of out-of-vocabulary buckets to use.
      sequence_length: The maximum number of words to take for each sequence.
      num_validation_examples: The number of test examples to use for
        validation.
      dim_embed: An integer for the dimension of the token embeddings.
      dim_model: An integer for the dimension of features of
        MultiHeadAttention layers.
      dim_hidden: An integer for the dimension of hidden layers of the FFN.
      num_heads:  An integer for the number of attention heads.
      num_layers: An integer for the number of Transformer blocks.
      max_position_encoding: Maximum number of positions for position
        embeddings.
      dropout: Dropout rate.
      client_datasets_random_seed: An optional int used to seed which clients
        are sampled at each round. If `None`, no seed is used.
      experiment_name: The name of the experiment being run. This will be
        appended to the `root_output_dir` for purposes of writing outputs.
      root_output_dir: The name of the root output directory for writing
        experiment outputs.
      max_val_test_batches: If set to a positive integer, val and test
        datasets are capped to at most that many batches. If set to None or a
        nonpositive integer, the full datasets are used.
      **kwargs: Additional arguments configuring the training loop. For
        details on supported arguments, see
        `federated_research/fedopt_guide/training_utils.py`.

    Returns:
      None. The composed training process, the client sampling function and
      the evaluation functions are handed directly to `training_loop.run`.
    """

    train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=vocab_size,
        max_sequence_length=sequence_length,
        num_validation_examples=num_validation_examples,
        num_oov_buckets=num_oov_buckets)

    # `None`, zero and negative values all mean "use the full datasets".
    if max_val_test_batches and max_val_test_batches >= 1:
        validation_dataset = validation_dataset.take(max_val_test_batches)
        test_dataset = test_dataset.take(max_val_test_batches)

    model_builder = functools.partial(
        transformer_models.create_transformer_lm,
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        dim_embed=dim_embed,
        dim_model=dim_model,
        dim_hidden=dim_hidden,
        num_heads=num_heads,
        num_layers=num_layers,
        max_position_encoding=max_position_encoding,
        dropout=dropout,
        name='stackoverflow-transformer')

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size, num_oov_buckets)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        """Returns a fresh list of the metrics tracked for this task."""
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
        vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
        num_oov_buckets=num_oov_buckets,
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        max_sequence_length=sequence_length,
        max_elements_per_client=max_elements_per_user)

    # The model's input spec is taken from the element type produced by the
    # training preprocessing computation.
    input_spec = train_dataset_preprocess_comp.type_signature.result.element

    def tff_model_fn() -> tff.learning.Model:
        """Wraps a newly built Keras model as a `tff.learning.Model`."""
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    def client_weight_fn(local_outputs):
        # Num_tokens is a tensor with type int64[1], to use as a weight need
        # a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    iterative_process = iterative_process_builder(
        tff_model_fn, client_weight_fn=client_weight_fn)

    if hasattr(train_clientdata, 'dataset_computation'):
        # The client data can build a dataset from a client ID inside a TFF
        # computation, so the composed process is driven by client IDs.

        @tff.tf_computation(tf.string)
        def train_dataset_computation(client_id):
            client_train_data = train_clientdata.dataset_computation(client_id)
            return train_dataset_preprocess_comp(client_train_data)

        training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
            train_dataset_computation, iterative_process)
        client_ids_fn = tff.simulation.build_uniform_sampling_fn(
            train_clientdata.client_ids,
            size=clients_per_round,
            replace=False,
            random_seed=client_datasets_random_seed)
        # We convert the output to a list (instead of an np.ndarray) so that it can
        # be used as input to the iterative process.
        client_sampling_fn = lambda x: list(client_ids_fn(x))
    else:
        training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
            train_dataset_preprocess_comp, iterative_process)
        client_sampling_fn = tff.simulation.build_uniform_client_sampling_fn(
            dataset=train_clientdata,
            clients_per_round=clients_per_round,
            random_seed=client_datasets_random_seed)

    # The composed process is a new object, so re-attach the accessor that the
    # docstring requires of the iterative process.
    training_process.get_model_weights = iterative_process.get_model_weights

    evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

    def validation_fn(model_weights, round_num):
        del round_num
        return evaluate_fn(model_weights, [validation_dataset])

    def test_fn(model_weights):
        # Evaluates on validation and test data together.
        return evaluate_fn(model_weights,
                           [validation_dataset.concatenate(test_dataset)])

    logging.info('Training model:')
    # `Model.summary()` prints to stdout and returns `None`, so logging its
    # return value would only log "None". Route each summary line through the
    # logger instead.
    model_builder().summary(print_fn=logging.info)

    training_loop.run(iterative_process=training_process,
                      train_client_datasets_fn=client_sampling_fn,
                      evaluation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)
예제 #11
0
def main(argv):
    """Runs federated averaging on Stack Overflow next-word prediction.

    Builds the recurrent model, loss, metrics, datasets and the federated
    averaging process from command-line flags, then hands everything to
    `training_loop.run`.

    Args:
      argv: Command-line arguments left over after flag parsing. Only the
        program name is expected.

    Raises:
      app.UsageError: If any positional command-line arguments are given.
      ValueError: If `FLAGS.noise_multiplier` is set while
        `FLAGS.uniform_weighting` is false.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        """Returns a fresh list of the metrics tracked for this task."""
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
        vocab_size=FLAGS.vocab_size,
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_sequence_length=FLAGS.sequence_length,
        max_elements_per_train_client=FLAGS.max_elements_per_user)
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=FLAGS.vocab_size,
        max_sequence_length=FLAGS.sequence_length,
        num_validation_examples=FLAGS.num_validation_examples)

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            # Every client contributes with the same weight.
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            # Weight each client by the `num_tokens` value it reports.
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        """Wraps a newly built Keras model as a TFF model."""
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=(
                FLAGS.clipped_count_budget_allocation),
            expected_clients_per_round=FLAGS.clients_per_round)

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    # The training loop passes a round number that validation does not use.
    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    # `Model.summary()` prints to stdout and returns `None`, so logging its
    # return value would only log "None". Route each summary line through the
    # logger instead.
    model_builder().summary(print_fn=logging.info)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
예제 #12
0
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    evaluation_computation_builder: Callable[..., tff.Computation],
    client_batch_size: int,
    clients_per_round: int,
    global_variables_only: bool,
    vocab_size: int = 10000,
    num_oov_buckets: int = 1,
    sequence_length: int = 20,
    max_elements_per_user: int = 1000,
    embedding_size: int = 96,
    latent_size: int = 670,
    num_layers: int = 1,
    total_rounds: int = 1500,
    experiment_name: str = 'federated_so_nwp',
    root_output_dir: str = '/tmp/fed_recon',
    split_dataset_strategy: str = (
        federated_trainer_utils.SPLIT_STRATEGY_AGGREGATED),
    split_dataset_proportion: int = 2,
    compose_dataset_computation: bool = False,
    **kwargs):
  """Trains a reconstruction model on Stack Overflow next word prediction.

  The Stack Overflow data is loaded and pre-processed, a reconstruction model
  is configured, and `iterative_process_builder` is asked for the iterative
  process that is then driven by `federated_research.utils.training_loop`.

  Only the embeddings of the most common words are sent as updates by this
  model. The embeddings of the out of vocabulary buckets are reconstructed on
  device when a round starts and are destroyed when the round ends.

  The iterative process is expected to expose these functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)`, with `S` the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>`, with `S`
        the server state, `{B*}` the client datasets and `T` a python
        `Mapping` object.

  It must additionally carry a callable `get_model_weights` attribute that
  maps the state of the iterative process to a `tff.learning.ModelWeights`
  object.

  Args:
    iterative_process_builder: Callable taking a no-arg `model_fn`, a
      `loss_fn`, a `metrics_fn` and a `client_weight_fn`, and returning a
      `tff.templates.IterativeProcess`. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. `federated_trainer.py`
      contains an example.
    evaluation_computation_builder: Callable taking a no-arg `model_fn`, a
      `loss_fn` and a `metrics_fn`, and returning a `tff.Computation` for
      federated reconstruction evaluation. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. `federated_trainer.py`
      contains an example.
    client_batch_size: Integer batch size used on clients.
    clients_per_round: Integer number of clients taking part in each round.
    global_variables_only: When True, every model variable of the
      `ReconstructionModel` is a global variable. Useful for baselines that
      aggregate all variables.
    vocab_size: Integer number of most frequent words kept in the vocabulary.
    num_oov_buckets: Number of out-of-vocabulary buckets.
    sequence_length: Maximum number of words taken for each sequence.
    max_elements_per_user: Maximum number of elements processed for each
      client's dataset.
    embedding_size: Dimension of the word embedding layer.
    latent_size: Dimension of the latent units in the recurrent layers.
    num_layers: Number of stacked recurrent layers.
    total_rounds: Number of federated training rounds.
    experiment_name: Name of the experiment being run. It is appended to
      `root_output_dir` when writing outputs.
    root_output_dir: Root output directory for experiment outputs.
    split_dataset_strategy: How the data is split. Must be `skip`, in which
      case every `split_dataset_proportion` example is used for
      reconstruction, or `aggregated`, in which case the first
      1/`split_dataset_proportion` proportion of the examples is used for
      reconstruction.
    split_dataset_proportion: Controls how much data goes to reconstruction:
      with a value of n, 1 / n of the data is used for reconstruction.
    compose_dataset_computation: Whether dataset computation is composed with
      the training and evaluation computations. If True, may speed up
      experiments by parallelizing dataset computations in multimachine
      setups. Not currently supported in OSS.
    **kwargs: Extra arguments configuring the training loop; the supported
      ones are described in `training_loop.py`.
  """

  loss_fn = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  tokens = stackoverflow_word_prediction.get_special_tokens(
      vocab_size, num_oov_buckets)
  pad_id = tokens.pad
  oov_ids = tokens.oov
  eos_id = tokens.eos

  def metrics_fn():
    """Builds a new list of metric objects on every call."""
    without_pad_oov = [pad_id] + oov_ids
    # Notice BOS never appears in ground truth.
    without_pad_eos_oov = [pad_id, eos_id] + oov_ids
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_id]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=without_pad_oov),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos', masked_tokens=without_pad_eos_oov),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_id]),
    ]

  train_data, val_data, test_data = (
      tff.simulation.datasets.stackoverflow.load_data())

  vocab = stackoverflow_word_prediction.create_vocab(vocab_size)
  preprocess_comp = stackoverflow_dataset.create_preprocess_fn(
      vocab=vocab,
      num_oov_buckets=num_oov_buckets,
      client_batch_size=client_batch_size,
      max_sequence_length=sequence_length,
      max_elements_per_client=max_elements_per_user,
      feature_dtypes=train_data.element_type_structure,
      sort_by_date=True)

  model_fn = functools.partial(
      models.create_recurrent_reconstruction_model,
      vocab_size=vocab_size,
      num_oov_buckets=num_oov_buckets,
      embedding_size=embedding_size,
      latent_size=latent_size,
      num_layers=num_layers,
      # The model consumes the element type emitted by the preprocessing.
      input_spec=preprocess_comp.type_signature.result.element,
      global_variables_only=global_variables_only)

  def client_weight_fn(local_outputs):
    # Num_tokens is a tensor with type int64[1]; a weight has to be a float32
    # scalar.
    token_count = tf.squeeze(local_outputs['num_tokens'])
    return tf.cast(token_count, tf.float32)

  def _new_split_fn_builder():
    """Returns a split-fn builder bound to the requested split settings."""
    return functools.partial(
        federated_trainer_utils.build_dataset_split_fn,
        split_dataset_strategy=split_dataset_strategy,
        split_dataset_proportion=split_dataset_proportion)

  iterative_process = iterative_process_builder(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_weight_fn=client_weight_fn,
      dataset_split_fn_builder=_new_split_fn_builder())

  base_eval_computation = evaluation_computation_builder(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      dataset_split_fn_builder=_new_split_fn_builder())

  if compose_dataset_computation:
    # Fold the dataset computations into client training and evaluation so
    # they are not computed centrally at linear cost. As a consequence the
    # `IterativeProcess` and `tff.Computation` take a list of client IDs
    # rather than datasets.
    compose_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process)
    preprocessed_process = compose_process(preprocess_comp, iterative_process)
    training_process = compose_process(train_data.dataset_computation,
                                       preprocessed_process)
    training_process.get_model_weights = iterative_process.get_model_weights

    compose_computation = (
        tff.simulation.compose_dataset_computation_with_computation)
    preprocessed_eval = compose_computation(preprocess_comp,
                                            base_eval_computation)
    val_computation = compose_computation(val_data.dataset_computation,
                                          preprocessed_eval)
    test_computation = compose_computation(test_data.dataset_computation,
                                           preprocessed_eval)

    # Samplers for train/val/test. They yield client IDs, not datasets, and
    # skip `preprocess_comp` because the composed training process and
    # evaluation computation already apply it.
    def _id_sampler(client_ids):
      return federated_trainer_utils.build_list_sample_fn(
          client_ids, size=clients_per_round, replace=False)

    train_client_datasets_fn = _id_sampler(train_data.client_ids)
    val_client_datasets_fn = _id_sampler(val_data.client_ids)
    test_client_datasets_fn = _id_sampler(test_data.client_ids)
  else:
    training_process = iterative_process
    val_computation = test_computation = base_eval_computation

    # Apply dataset computations.
    train_data = train_data.preprocess(preprocess_comp)
    val_data = val_data.preprocess(preprocess_comp)
    test_data = test_data.preprocess(preprocess_comp)

    # Samplers for train/val/test.
    def _uniform_sampler(client_ids):
      return functools.partial(
          tff.simulation.build_uniform_sampling_fn(client_ids),
          size=clients_per_round)

    train_client_datasets_fn = _uniform_sampler(train_data.client_ids)
    val_client_datasets_fn = _uniform_sampler(val_data.client_ids)
    test_client_datasets_fn = _uniform_sampler(test_data.client_ids)

  # Evaluation callables handed to `training_loop`.
  val_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=val_computation,
      client_datasets_fn=val_client_datasets_fn,
      get_model=training_process.get_model_weights)
  round_aware_test_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=test_computation,
      client_datasets_fn=test_client_datasets_fn,
      get_model=training_process.get_model_weights)
  # The test function is always invoked with round number 0.
  test_fn = functools.partial(round_aware_test_fn, round_num=0)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=train_client_datasets_fn,
      validation_fn=val_fn,
      test_fn=test_fn,
      total_rounds=total_rounds,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      **kwargs)
예제 #13
0
def main(argv):
    """Configure and run a federated averaging experiment from flag values.

    Builds the recurrent model, loss and metrics builders, loads the federated
    training data plus the centralized validation/test data, selects a model
    update aggregator (differentially private when `noise_multiplier` is
    set), writes the hyperparameter flags to `hparams.csv` under the results
    directory, and finally hands everything to `training_loop.run`.

    Args:
        argv: Argument list; anything beyond a single entry is rejected.

    Raises:
        app.UsageError: If `argv` contains more than one entry.
        ValueError: If `noise_multiplier` is set but uniform weighting is off,
            or `noise_multiplier`, `clip` or a (truthy)
            `adaptive_clip_learning_rate` is not positive.
    """
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        """Return fresh metric objects; each accuracy masks different tokens."""
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    # Training uses the federated (per-client) data; validation and test use
    # the centralized datasets.
    train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
        vocab_size=FLAGS.vocab_size,
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        max_sequence_length=FLAGS.sequence_length,
        max_elements_per_train_client=FLAGS.max_elements_per_user)
    _, validation_dataset, test_dataset = (
        stackoverflow_word_prediction.get_centralized_datasets(
            vocab_size=FLAGS.vocab_size,
            max_sequence_length=FLAGS.sequence_length,
            num_validation_examples=FLAGS.num_validation_examples))

    if FLAGS.uniform_weighting:
        client_weighting = tff.learning.ClientWeighting.UNIFORM
    else:
        client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

    def model_fn():
        """Wrap a newly built Keras model, loss and metrics for TFF."""
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        # Differential privacy requested: validate the related flags first.
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )
        if FLAGS.noise_multiplier <= 0:
            raise ValueError(
                'noise_multiplier must be positive if DP is enabled.')
        if FLAGS.clip is None or FLAGS.clip <= 0:
            raise ValueError('clip must be positive if DP is enabled.')

        dp_factory = tff.aggregators.DifferentiallyPrivateFactory
        # A falsy learning rate (None or 0) selects fixed clipping, so the
        # positivity check below only ever rejects negative values.
        if not FLAGS.adaptive_clip_learning_rate:
            aggregation_factory = dp_factory.gaussian_fixed(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                clip=FLAGS.clip)
        else:
            if FLAGS.adaptive_clip_learning_rate <= 0:
                raise ValueError(
                    'adaptive_clip_learning_rate must be positive if '
                    'adaptive clipping is enabled.')
            aggregation_factory = dp_factory.gaussian_adaptive(
                noise_multiplier=FLAGS.noise_multiplier,
                clients_per_round=FLAGS.clients_per_round,
                initial_l2_norm_clip=FLAGS.clip,
                target_unclipped_quantile=FLAGS.target_unclipped_quantile,
                learning_rate=FLAGS.adaptive_clip_learning_rate)
    else:
        if FLAGS.uniform_weighting:
            aggregation_factory = tff.aggregators.UnweightedMeanFactory()
        else:
            aggregation_factory = tff.aggregators.MeanFactory()

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    # The training loop passes a round number; centralized evaluation
    # ignores it.
    validation_fn = lambda state, round_num: evaluate_fn(state.model)

    evaluate_test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    test_fn = lambda state: evaluate_test_fn(state.model)

    logging.info('Training model:')
    # Keras `summary()` prints its output and returns None, so logging its
    # return value would only record "None". Route the summary text through
    # the logger instead.
    model_builder().summary(print_fn=logging.info)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)