コード例 #1
0
 def test_raises_no_repeat_and_no_take(self):
     with self.assertRaisesRegex(
             ValueError, 'Argument client_epochs_per_round is set to -1'):
         stackoverflow_lr_dataset.get_stackoverflow_datasets(
             vocab_tokens_size=1000,
             vocab_tags_size=500,
             max_training_elements_per_user=128,
             client_batch_size=10,
             client_epochs_per_round=-1,
             max_batches_per_user=-1,
             num_validation_examples=500)
コード例 #2
0
 def test_take_with_repeat(self):
     so_train, _, _ = stackoverflow_lr_dataset.get_stackoverflow_datasets(
         vocab_tokens_size=1000,
         vocab_tags_size=500,
         max_training_elements_per_user=128,
         client_batch_size=10,
         client_epochs_per_round=-1,
         max_batches_per_user=8,
         num_validation_examples=500)
     for i in range(10):
         client_ds = so_train.create_tf_dataset_for_client(
             so_train.client_ids[i])
         self.assertEqual(_compute_length_of_dataset(client_ds), 8)
コード例 #3
0
  def test_stackoverflow_dataset_structure(self):
    stackoverflow_train, stackoverflow_validation, stackoverflow_test = stackoverflow_lr_dataset.get_stackoverflow_datasets(
        vocab_tokens_size=100,
        vocab_tags_size=5,
        max_training_elements_per_user=10,
        client_batch_size=10,
        client_epochs_per_round=1,
        num_validation_examples=10000)
    self.assertEqual(len(stackoverflow_train.client_ids), 342477)
    sample_train_ds = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0])

    train_batch = next(iter(sample_train_ds))
    valid_batch = next(iter(stackoverflow_validation))
    test_batch = next(iter(stackoverflow_test))
    self.assertEqual(train_batch[0].shape.as_list(), [10, 100])
    self.assertEqual(train_batch[1].shape.as_list(), [10, 5])
    self.assertEqual(valid_batch[0].shape.as_list(), [TEST_BATCH_SIZE, 100])
    self.assertEqual(valid_batch[1].shape.as_list(), [TEST_BATCH_SIZE, 5])
    self.assertEqual(test_batch[0].shape.as_list(), [TEST_BATCH_SIZE, 100])
    self.assertEqual(test_batch[1].shape.as_list(), [TEST_BATCH_SIZE, 5])
コード例 #4
0
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        max_batches_per_client: Optional[int] = -1,
        client_datasets_random_seed: Optional[int] = None,
        vocab_tokens_size: Optional[int] = 10000,
        vocab_tags_size: Optional[int] = 500,
        max_elements_per_user: Optional[int] = 1000,
        num_validation_examples: Optional[int] = 10000,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_so_lr',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        max_eval_batches: Optional[int] = None,
        **kwargs):
    """Runs an iterative process on the Stack Overflow logistic regression task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset is
      used.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

    stackoverflow_train, _, _ = stackoverflow_lr_dataset.get_stackoverflow_datasets(
        vocab_tokens_size=vocab_tokens_size,
        vocab_tags_size=vocab_tags_size,
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        max_training_elements_per_user=max_elements_per_user,
        max_batches_per_user=max_batches_per_client,
        num_validation_examples=num_validation_examples)

    _, stackoverflow_validation, stackoverflow_test = stackoverflow_lr_dataset.get_centralized_datasets(
        train_batch_size=client_batch_size,
        vocab_tokens_size=vocab_tokens_size,
        vocab_tags_size=vocab_tags_size,
        num_validation_examples=num_validation_examples,
        max_validation_batches=max_eval_batches,
        max_test_batches=max_eval_batches)

    input_spec = stackoverflow_train.create_tf_dataset_for_client(
        stackoverflow_train.client_ids[0]).element_spec

    model_builder = functools.partial(
        stackoverflow_lr_models.create_logistic_model,
        vocab_tokens_size=vocab_tokens_size,
        vocab_tags_size=vocab_tags_size)

    loss_builder = functools.partial(tf.keras.losses.BinaryCrossentropy,
                                     from_logits=False,
                                     reduction=tf.keras.losses.Reduction.SUM)

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    training_process = iterative_process_builder(tff_model_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=stackoverflow_train,
        train_clients_per_round=clients_per_round,
        random_seed=client_datasets_random_seed)

    evaluate_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=stackoverflow_validation,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=test_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)