Example #1
    def test_raises_non_callable_get_model_weights_attribute(self):
        class BadIterativeProcess(tff.templates.IterativeProcess):
            def __init__(self):
                pass

            def initialize(self):
                return {}

            def next(self, state, data):
                return {}

        iterative_process = BadIterativeProcess()
        iterative_process.get_model_weights = 2
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        with self.assertRaisesRegex(
                training_loop.IterativeProcessCompatibilityError,
                _COMPATIBILITY_ERROR_MESSAGE):
            training_loop.run(iterative_process=iterative_process,
                              client_datasets_fn=client_datasets_fn,
                              validation_fn=validation_fn,
                              total_rounds=10,
                              experiment_name='bad_iterative_process')
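The tests in this listing rely on two module-level helpers that the excerpt does not show. A minimal sketch of what they plausibly look like for the MNIST model used here (names, shapes, and optimizer settings are assumptions, not the original implementation):

import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Batch = collections.namedtuple('Batch', ['x', 'y'])

def _batch_fn():
    # One dummy MNIST batch: a single flattened 28x28 image and its label.
    return Batch(
        x=np.ones([1, 784], dtype=np.float32),
        y=np.ones([1, 1], dtype=np.int64))

def _model_fn():
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=False)
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=Batch(
            x=tf.TensorSpec([None, 784], tf.float32),
            y=tf.TensorSpec([None, 1], tf.int64)),
        loss=tf.keras.losses.SparseCategoricalCrossentropy())

def _build_federated_averaging_process():
    return tff.learning.build_federated_averaging_process(
        _model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))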
Example #2
  def test_fn_writes_metrics(self):
    experiment_name = 'test_metrics'
    iterative_process = _build_federated_averaging_process()
    batch = _batch_fn()
    federated_data = [[batch]]

    def client_datasets_fn(round_num):
      del round_num
      return federated_data

    def evaluate(model):
      keras_model = tff.simulation.models.mnist.create_keras_model(
          compile_model=True)
      model.assign_weights_to(keras_model)
      return {'loss': keras_model.evaluate(batch.x, batch.y)}

    root_output_dir = self.get_temp_dir()
    training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=evaluate,
        total_rounds=1,
        experiment_name=experiment_name,
        root_output_dir=root_output_dir,
        rounds_per_eval=10,
        test_fn=evaluate)

    results_dir = os.path.join(root_output_dir, 'results', experiment_name)

    scalar_manager = metrics_manager.ScalarMetricsManager(results_dir)
    fieldnames, metrics = scalar_manager.get_metrics()
    self.assertLen(metrics, 2)
    self.assertIn('eval/loss', fieldnames)
    self.assertIn('test/loss', fieldnames)
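For reference, `get_metrics()` returns the CSV column names together with one row per logged round; the two assertions above check that both the validation and test evaluations landed in the header. A hedged sketch of inspecting the rows (the row structure is an assumption):

# Hypothetical inspection; each row maps a column name to a value.
for row in metrics:
  print(row.get('eval/loss'), row.get('test/loss'))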
Example #3
  def test_raises_no_model_attribute_in_state(self):

    class BadIterativeProcess(tff.templates.IterativeProcess):
      """Converts iterative process results from anonymous tuples."""

      def __init__(self):
        pass

      def initialize(self):
        return {}

      def next(self, state, data):
        return {}

    iterative_process = BadIterativeProcess()
    federated_data = [[_batch_fn()]]

    def client_datasets_fn(round_num):
      del round_num
      return federated_data

    def validation_fn(model):
      del model
      return {}

    with self.assertRaisesRegex(TypeError,
                                'The server state must have a model attribute'):
      training_loop.run(
          iterative_process=iterative_process,
          client_datasets_fn=client_datasets_fn,
          validation_fn=validation_fn,
          total_rounds=10,
          experiment_name='bad_iterative_process')
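The check exists because the training loop reads `state.model`. A minimal sketch of a server state object that would pass this particular check; the named-tuple form is an assumption:

import collections

# Hypothetical server state: the loop only requires a `model` attribute
# holding `tff.learning.ModelWeights`.
ServerState = collections.namedtuple('ServerState', ['model'])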
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'client')
    server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'server')

    compression_dict = utils_impl.lookup_flag_values(compression_flags)
    dp_dict = utils_impl.lookup_flag_values(dp_flags)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`."""

        model_trainable_variables = model_fn().trainable_variables

        # Most logic for deciding what to run is here.
        aggregation_factory = fl_utils.build_aggregator(
            compression_flags=compression_dict,
            dp_flags=dp_dict,
            num_clients=get_total_num_clients(FLAGS.task),
            num_clients_per_round=FLAGS.clients_per_round,
            num_rounds=FLAGS.total_rounds,
            client_template=model_trainable_variables)

        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weighting=tff.learning.ClientWeighting.UNIFORM,
            client_optimizer_fn=client_optimizer_fn,
            model_update_aggregation_factory=aggregation_factory)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
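`main` assumes a number of absl flags defined elsewhere in the module (plus `compression_flags`, `dp_flags`, `_SUPPORTED_TASKS`, and optimizer flags consumed by `utils_impl.create_optimizer_from_flags`). A minimal sketch of the definitions it relies on; names are taken from usage above, defaults are hypothetical:

from absl import flags

FLAGS = flags.FLAGS
_SUPPORTED_TASKS = ['stackoverflow_lr']

flags.DEFINE_enum('task', None, _SUPPORTED_TASKS, 'Which task to run.')
flags.DEFINE_integer('total_rounds', 200, 'Number of training rounds.')
flags.DEFINE_integer('clients_per_round', 10, 'Clients sampled per round.')
flags.DEFINE_integer('client_epochs_per_round', 1,
                     'Local epochs per client per round.')
flags.DEFINE_integer('client_batch_size', 20, 'Client batch size.')
flags.DEFINE_integer('client_datasets_random_seed', 1,
                     'Seed for client sampling.')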
Example #5
    def test_run_federated(self, task_name, config_fn):
        task_spec = training_specs.TaskSpec(
            iterative_process_builder=iterative_process_builder,
            client_epochs_per_round=1,
            client_batch_size=10,
            clients_per_round=1,
            sampling_random_seed=1)
        runner_spec = config_fn(task_spec)

        total_rounds = 1
        root_output_dir = self.get_temp_dir()
        exp_name = 'test_run_federated'

        training_loop.run(
            iterative_process=runner_spec.iterative_process,
            client_datasets_fn=runner_spec.client_datasets_fn,
            validation_fn=runner_spec.validation_fn,
            # For efficiency, we avoid using the entire test set here
            test_fn=None,
            total_rounds=total_rounds,
            root_output_dir=root_output_dir,
            experiment_name=exp_name)

        results_dir = os.path.join(root_output_dir, 'results', exp_name)
        self.assertTrue(tf.io.gfile.exists(results_dir))

        scalar_manager = tff.simulation.CSVMetricsManager(
            os.path.join(results_dir, 'experiment.metrics.csv'))
        fieldnames, metrics = scalar_manager.get_metrics()

        self.assertIn(
            'train/train/loss',
            fieldnames,
            msg='The output metrics should have a `train/train/loss` column '
            'if training is successful.')
        self.assertIn(
            'eval/loss',
            fieldnames,
            msg='The output metrics should have an `eval/loss` column if '
            'validation metrics computation is successful.')
        self.assertLen(
            metrics,
            total_rounds + 1,
            msg='The number of rows in the metrics CSV should be the number '
            'of training rounds + 1 (as there is an extra row for validation '
            'set metrics after training has completed).')
Example #6
  def test_fedavg_training_decreases_loss(self):
    batch = _batch_fn()
    federated_data = [[batch]]
    iterative_process = _build_federated_averaging_process()

    def client_datasets_fn(round_num):
      del round_num
      return federated_data

    def validation_fn(model):
      keras_model = tff.simulation.models.mnist.create_keras_model(
          compile_model=True)
      model.assign_weights_to(keras_model)
      return {'loss': keras_model.evaluate(batch.x, batch.y)}

    initial_state = iterative_process.initialize()

    root_output_dir = self.get_temp_dir()
    final_state = training_loop.run(
        iterative_process=iterative_process,
        client_datasets_fn=client_datasets_fn,
        validation_fn=validation_fn,
        total_rounds=1,
        experiment_name='fedavg_decreases_loss',
        root_output_dir=root_output_dir)
    self.assertLess(
        validation_fn(final_state.model)['loss'],
        validation_fn(initial_state.model)['loss'])
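The same effect can be seen by driving the process manually; for a process built with `tff.learning.build_federated_averaging_process`, `next` returns the updated server state together with a metrics structure (a sketch):

state = iterative_process.initialize()
for round_num in range(2):
    # Each call runs one round of federated averaging on `federated_data`.
    state, round_metrics = iterative_process.next(state, federated_data)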
Example #7
    def test_raises_non_callable_client_dataset(self):
        iterative_process = _build_federated_averaging_process()
        client_dataset = [[_batch_fn()]]

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        root_output_dir = self.get_temp_dir()
        with self.assertRaises(TypeError):
            training_loop.run(iterative_process=iterative_process,
                              client_datasets_fn=client_dataset,
                              validation_fn=validation_fn,
                              total_rounds=10,
                              experiment_name='non_callable_client_dataset',
                              root_output_dir=root_output_dir)
Example #8
    def test_raises_non_callable_evaluate_fn(self):
        iterative_process = _build_federated_averaging_process()
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        metrics_dict = {}
        root_output_dir = self.get_temp_dir()
        with self.assertRaises(TypeError):
            training_loop.run(iterative_process=iterative_process,
                              client_datasets_fn=client_datasets_fn,
                              validation_fn=metrics_dict,
                              total_rounds=10,
                              experiment_name='non_callable_evaluate',
                              root_output_dir=root_output_dir)
Example #9
    def test_raises_non_str_output_dir(self):
        iterative_process = _build_federated_averaging_process()
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        with self.assertRaises(TypeError):
            training_loop.run(iterative_process=iterative_process,
                              client_datasets_fn=client_datasets_fn,
                              validation_fn=validation_fn,
                              total_rounds=10,
                              experiment_name='non_str_output_dir',
                              root_output_dir=1)
Example #10
    def test_raises_non_iterative_process(self):
        bad_iterative_process = _build_federated_averaging_process().next
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        root_output_dir = self.get_temp_dir()
        with self.assertRaises(TypeError):
            training_loop.run(iterative_process=[bad_iterative_process],
                              client_datasets_fn=client_datasets_fn,
                              validation_fn=validation_fn,
                              total_rounds=10,
                              experiment_name='non_iterative_process',
                              root_output_dir=root_output_dir)
Example #11
    def test_training_loop_works_on_restart(self):
        experiment_name = 'test_metrics'
        iterative_process = _build_federated_averaging_process()
        batch = _batch_fn()
        federated_data = [[batch]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        root_output_dir = self.get_temp_dir()
        training_loop.run(iterative_process=iterative_process,
                          client_datasets_fn=client_datasets_fn,
                          validation_fn=validation_fn,
                          total_rounds=2,
                          experiment_name=experiment_name,
                          root_output_dir=root_output_dir)

        training_loop.run(iterative_process=iterative_process,
                          client_datasets_fn=client_datasets_fn,
                          validation_fn=validation_fn,
                          total_rounds=4,
                          experiment_name=experiment_name,
                          root_output_dir=root_output_dir)

        csv_file = os.path.join(root_output_dir, 'results', experiment_name,
                                'experiment.metrics.csv')
        metrics_manager = tff.simulation.CSVMetricsManager(csv_file)
        _, metrics = metrics_manager.get_metrics()
        # Test that metrics are logged for all 4 training rounds, plus an
        # extra row for validation metrics after training has completed.
        self.assertLen(metrics, 5)
Example #12
    def test_checkpoint_manager_saves_state(self):
        experiment_name = 'checkpoint_manager_saves_state'
        iterative_process = _build_federated_averaging_process()
        federated_data = [[_batch_fn()]]

        def client_datasets_fn(round_num):
            del round_num
            return federated_data

        def validation_fn(model, round_num):
            del model, round_num
            return {}

        root_output_dir = self.get_temp_dir()
        final_state = training_loop.run(iterative_process=iterative_process,
                                        client_datasets_fn=client_datasets_fn,
                                        validation_fn=validation_fn,
                                        total_rounds=1,
                                        experiment_name=experiment_name,
                                        root_output_dir=root_output_dir)
        final_model = iterative_process.get_model_weights(final_state)

        ckpt_manager = tff.simulation.FileCheckpointManager(
            os.path.join(root_output_dir, 'checkpoints', experiment_name))
        restored_state, restored_round = ckpt_manager.load_latest_checkpoint(
            final_state)

        self.assertEqual(restored_round, 0)

        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=True)
        restored_model = iterative_process.get_model_weights(restored_state)

        restored_model.assign_weights_to(keras_model)
        restored_loss = keras_model.test_on_batch(federated_data[0][0].x,
                                                  federated_data[0][0].y)

        final_model.assign_weights_to(keras_model)
        final_loss = keras_model.test_on_batch(federated_data[0][0].x,
                                               federated_data[0][0].y)
        self.assertEqual(final_loss, restored_loss)
Example #13
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')

  def iterative_process_builder(
      model_fn: Callable[[], tff.learning.Model],
  ) -> tff.templates.IterativeProcess:
    """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

    logging.info('Trainable weights:')
    for weight in model_fn().weights.trainable:
      logging.info('name: %s  shape: %s', weight.name, weight.shape)

    if FLAGS.uniform_weighting:
      client_weighting = tff.learning.ClientWeighting.UNIFORM
    elif FLAGS.task == 'shakespeare' or FLAGS.task == 'stackoverflow_nwp':

      def client_weighting(local_outputs):
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)
    else:
      client_weighting = None

    if FLAGS.noise_multiplier is None:
      if FLAGS.uniform_weighting:
        aggregation_factory = tff.aggregators.UnweightedMeanFactory()
      else:
        aggregation_factory = tff.aggregators.MeanFactory()
      if FLAGS.clip is not None:
        if FLAGS.clip <= 0:
          raise ValueError('clip must be positive if clipping is enabled.')
        if FLAGS.adaptive_clip_learning_rate is None:
          clip = FLAGS.clip
        else:
          if FLAGS.adaptive_clip_learning_rate <= 0:
            raise ValueError('adaptive_clip_learning_rate must be positive if '
                             'adaptive clipping is enabled.')
          clip = tff.aggregators.PrivateQuantileEstimationProcess.no_noise(
              initial_estimate=FLAGS.clip,
              target_quantile=FLAGS.target_unclipped_quantile,
              learning_rate=FLAGS.adaptive_clip_learning_rate)
        aggregation_factory = tff.aggregators.clipping_factory(
            clip, aggregation_factory)
    else:
      if not FLAGS.uniform_weighting:
        raise ValueError(
            'Differential privacy is only implemented for uniform weighting.')
      if FLAGS.noise_multiplier <= 0:
        raise ValueError('noise_multiplier must be positive if DP is enabled.')
      if FLAGS.clip is None or FLAGS.clip <= 0:
        raise ValueError('clip must be positive if DP is enabled.')
      if FLAGS.adaptive_clip_learning_rate is None:
        aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
            noise_multiplier=FLAGS.noise_multiplier,
            clients_per_round=FLAGS.clients_per_round,
            clip=FLAGS.clip)
      else:
        if FLAGS.adaptive_clip_learning_rate <= 0:
          raise ValueError('adaptive_clip_learning_rate must be positive if '
                           'adaptive clipping is enabled.')
        aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
            noise_multiplier=FLAGS.noise_multiplier,
            clients_per_round=FLAGS.clients_per_round,
            initial_l2_norm_clip=FLAGS.clip,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            learning_rate=FLAGS.adaptive_clip_learning_rate)

    return tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weighting=client_weighting,
        client_optimizer_fn=client_optimizer_fn,
        model_update_aggregation_factory=aggregation_factory)

  task_spec = training_specs.TaskSpec(
      iterative_process_builder=iterative_process_builder,
      client_epochs_per_round=FLAGS.client_epochs_per_round,
      client_batch_size=FLAGS.client_batch_size,
      clients_per_round=FLAGS.clients_per_round,
      client_datasets_random_seed=FLAGS.client_datasets_random_seed)

  if FLAGS.task == 'cifar100':
    runner_spec = federated_cifar100.configure_training(task_spec)
  elif FLAGS.task == 'emnist_cr':
    runner_spec = federated_emnist.configure_training(task_spec)
  elif FLAGS.task == 'emnist_ae':
    runner_spec = federated_emnist_ae.configure_training(task_spec)
  elif FLAGS.task == 'shakespeare':
    runner_spec = federated_shakespeare.configure_training(task_spec)
  elif FLAGS.task == 'stackoverflow_nwp':
    runner_spec = federated_stackoverflow.configure_training(task_spec)
  elif FLAGS.task == 'stackoverflow_lr':
    runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
  else:
    raise ValueError(
        '--task flag {} is not supported, must be one of {}.'.format(
            FLAGS.task, _SUPPORTED_TASKS))

  _write_hparam_flags()

  training_loop.run(
      iterative_process=runner_spec.iterative_process,
      client_datasets_fn=runner_spec.client_datasets_fn,
      validation_fn=runner_spec.validation_fn,
      test_fn=runner_spec.test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
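The branching above selects the aggregator roughly as follows: with no `noise_multiplier` it uses a plain (optionally clipped) weighted or unweighted mean, and with DP enabled it requires uniform weighting plus a positive clip, choosing fixed or adaptive clipping based on `adaptive_clip_learning_rate`. For instance, the fixed-clip DP branch reduces to a call like this (values hypothetical):

# Equivalent to the gaussian_fixed branch above, with hypothetical values.
aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    noise_multiplier=1.0, clients_per_round=100, clip=0.5)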
Example #14
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    client_epochs_per_round: int,
    client_batch_size: int,
    clients_per_round: int,
    client_datasets_random_seed: Optional[int] = None,
    vocab_tokens_size: Optional[int] = 10000,
    vocab_tags_size: Optional[int] = 500,
    max_elements_per_user: Optional[int] = 1000,
    num_validation_examples: Optional[int] = 10000,
    total_rounds: Optional[int] = 1500,
    experiment_name: Optional[str] = 'federated_so_lr',
    root_output_dir: Optional[str] = '/tmp/fed_opt',
    max_eval_batches: Optional[int] = None,
    **kwargs):
  """Runs an iterative process on the Stack Overflow logistic regression task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    vocab_tokens_size: Integer dictating the number of most frequent words to
      use in the vocabulary.
    vocab_tags_size: Integer dictating the number of most frequent tags to use
      in the label creation.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

  stackoverflow_train, _ = stackoverflow_tag_prediction.get_federated_datasets(
      word_vocab_size=vocab_tokens_size,
      tag_vocab_size=vocab_tags_size,
      train_client_batch_size=client_batch_size,
      train_client_epochs_per_round=client_epochs_per_round,
      max_elements_per_train_client=max_elements_per_user)

  _, stackoverflow_validation, stackoverflow_test = stackoverflow_tag_prediction.get_centralized_datasets(
      train_batch_size=client_batch_size,
      word_vocab_size=vocab_tokens_size,
      tag_vocab_size=vocab_tags_size,
      num_validation_examples=num_validation_examples)

  if max_eval_batches and max_eval_batches >= 1:
    stackoverflow_validation = stackoverflow_validation.take(max_eval_batches)
    stackoverflow_test = stackoverflow_test.take(max_eval_batches)

  input_spec = stackoverflow_train.create_tf_dataset_for_client(
      stackoverflow_train.client_ids[0]).element_spec

  model_builder = functools.partial(
      stackoverflow_lr_models.create_logistic_model,
      vocab_tokens_size=vocab_tokens_size,
      vocab_tags_size=vocab_tags_size)

  loss_builder = functools.partial(
      tf.keras.losses.BinaryCrossentropy,
      from_logits=False,
      reduction=tf.keras.losses.Reduction.SUM)

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  training_process = iterative_process_builder(tff_model_fn)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      dataset=stackoverflow_train,
      clients_per_round=clients_per_round,
      random_seed=client_datasets_random_seed)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=stackoverflow_validation,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=stackoverflow_validation.concatenate(stackoverflow_test),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  logging.info('Training model:')
  logging.info(model_builder().summary())

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=evaluate_fn,
      test_fn=test_fn,
      total_rounds=total_rounds,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      **kwargs)
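This snippet calls a module-level `metrics_builder` that the excerpt does not show. For a multi-label tag-prediction task trained with binary cross-entropy, it is plausibly something like the following sketch (the metric choices are assumptions):

def metrics_builder():
    # Hypothetical metrics for the multi-label tag prediction task.
    return [
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
    ]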
Example #15
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        schedule: Optional[str] = 'none',
        beta: Optional[float] = 0.,
        max_batches_per_client: Optional[int] = -1,
        client_datasets_random_seed: Optional[int] = None,
        model: Optional[str] = 'cnn',
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_synthetic',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        max_eval_batches: Optional[int] = None,
        alpha: Optional[float] = 0.,
        beta_data: Optional[float] = 0.,
        iid: Optional[int] = 0,
        num_users: Optional[int] = 100,
        **kwargs):
    """Runs an iterative process on the EMNIST character recognition task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset is
      used.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    model: A string specifying the model used for character recognition.
      Can be one of `cnn` and `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model (respectively).
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """
    logging.info('DATA PARAMS:')
    logging.info(f'  Num Users: {num_users}')
    logging.info(f'  alpha: {alpha}')
    logging.info(f'  beta: {beta_data}')
    logging.info(f'  iid: {iid}')
    train_data, test_data, federated_test = synthetic_dataset.generate_federated_softmax_data(
        batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        test_batch_size=100,
        alpha=alpha,
        beta=beta_data,
        iid=iid,
        num_users=num_users)

    input_spec = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0]).element_spec

    model_builder = functools.partial(create_logistic_regression_model)

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    training_process = iterative_process_builder(tff_model_fn)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=test_data,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    test_fn = training_utils.build_unweighted_test_fn(
        federated_eval_dataset=federated_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())
    try:
        var = kwargs['hparam_dict']['var_q_clients']
        logging.info(f'Variance: {var}')
        q_client = np.load(
            f'/home/monica/AVAIL_VECTORS/q_client_{var}_synthetic.npy')
    except (KeyError, OSError):
        logging.info(
            'Could not load q_client - initializing random availabilities')
        q_client = None

    if schedule == 'none':
        client_datasets_fn = training_utils.build_client_datasets_fn(
            train_dataset=train_data,
            train_clients_per_round=clients_per_round,
            random_seed=client_datasets_random_seed,
            min_clients=kwargs['hparam_dict']['min_clients'],
            var_q_clients=kwargs['hparam_dict']['var_q_clients'],
            f_mult=kwargs['hparam_dict']['f_mult'],
            f_intercept=kwargs['hparam_dict']['f_intercept'],
            sine_wave=kwargs['hparam_dict']['sine_wave'],
            use_p=True,
            q_client=q_client,
        )
        training_loop.run(iterative_process=training_process,
                          client_datasets_fn=client_datasets_fn,
                          validation_fn=evaluate_fn,
                          test_fn=test_fn,
                          total_rounds=total_rounds,
                          experiment_name=experiment_name,
                          root_output_dir=root_output_dir,
                          **kwargs)
    elif schedule == 'loss':
        if 'loss_pool_size' in kwargs['hparam_dict'] and kwargs['hparam_dict'][
                'loss_pool_size'] is not None:
            loss_pool_size = kwargs['hparam_dict']['loss_pool_size']
            logging.info(f'Loss pool size: {loss_pool_size}')

            client_datasets_fn = training_utils.build_client_datasets_fn(
                train_dataset=train_data,
                train_clients_per_round=loss_pool_size,
                random_seed=client_datasets_random_seed,
                min_clients=kwargs['hparam_dict']['min_clients'],
                var_q_clients=kwargs['hparam_dict']['var_q_clients'],
                f_mult=kwargs['hparam_dict']['f_mult'],
                f_intercept=kwargs['hparam_dict']['f_intercept'],
                sine_wave=kwargs['hparam_dict']['sine_wave'],
                use_p=True,
                q_client=q_client)
            training_loop_loss.run(iterative_process=training_process,
                                   client_datasets_fn=client_datasets_fn,
                                   validation_fn=evaluate_fn,
                                   test_fn=test_fn,
                                   total_rounds=total_rounds,
                                   total_clients=loss_pool_size,
                                   experiment_name=experiment_name,
                                   root_output_dir=root_output_dir,
                                   **kwargs)
        else:
            raise ValueError('Loss pool size not specified')
    else:
        init_p = kwargs['hparam_dict']['initialize_p']
        logging.info(f'Initializing as p = {init_p}')
        client_datasets_fn = training_utils.build_availability_client_datasets_fn(
            train_dataset=train_data,
            train_clients_per_round=clients_per_round,
            beta=beta,
            min_clients=kwargs['hparam_dict']['min_clients'],
            var_q_clients=kwargs['hparam_dict']['var_q_clients'],
            f_mult=kwargs['hparam_dict']['f_mult'],
            f_intercept=kwargs['hparam_dict']['f_intercept'],
            sine_wave=kwargs['hparam_dict']['sine_wave'],
            q_client=q_client,
            initialize_p=init_p,
        )
        training_loop_importance.run(iterative_process=training_process,
                                     client_datasets_fn=client_datasets_fn,
                                     validation_fn=evaluate_fn,
                                     test_fn=test_fn,
                                     total_rounds=total_rounds,
                                     experiment_name=experiment_name,
                                     root_output_dir=root_output_dir,
                                     **kwargs)
Example #16
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        client_datasets_random_seed: Optional[int] = None,
        vocab_size: Optional[int] = 10000,
        num_oov_buckets: Optional[int] = 1,
        sequence_length: Optional[int] = 20,
        max_elements_per_user: Optional[int] = 1000,
        num_validation_examples: Optional[int] = 10000,
        embedding_size: Optional[int] = 96,
        latent_size: Optional[int] = 670,
        num_layers: Optional[int] = 1,
        shared_embedding: Optional[bool] = False,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_so_nwp',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        **kwargs):
    """Runs an iterative process on the Stack Overflow next word prediction task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, a
      `client_weight_fn` and returns a `tff.templates.IterativeProcess`. The
      `model_fn` must return a `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    vocab_size: Integer dictating the number of most frequent words to use in
      the vocabulary.
    num_oov_buckets: The number of out-of-vocabulary buckets to use.
    sequence_length: The maximum number of words to take for each sequence.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    num_validation_examples: The number of test examples to use for validation.
    embedding_size: The dimension of the word embedding layer.
    latent_size: The dimension of the latent units in the recurrent layers.
    num_layers: The number of stacked recurrent layers to use.
    shared_embedding: Boolean indicating whether to tie input and output
      embeddings.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=vocab_size,
        num_oov_buckets=num_oov_buckets,
        embedding_size=embedding_size,
        latent_size=latent_size,
        num_layers=num_layers,
        shared_embedding=shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        vocab_size, num_oov_buckets)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
        ]

    train_clientdata, _, _ = tff.simulation.datasets.stackoverflow.load_data()

    # TODO(b/161914546): consider moving evaluation to use
    # `tff.learning.build_federated_evaluation` to get metrics over client
    # distributions, as well as the example weight means from this centralized
    # evaluation.
    _, validation_dataset, test_dataset = stackoverflow_word_prediction.get_centralized_datasets(
        vocab_size=vocab_size,
        max_sequence_length=sequence_length,
        num_validation_examples=num_validation_examples,
        num_oov_buckets=num_oov_buckets)

    train_dataset_preprocess_comp = stackoverflow_word_prediction.create_preprocess_fn(
        vocab=stackoverflow_word_prediction.create_vocab(vocab_size),
        num_oov_buckets=num_oov_buckets,
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        max_sequence_length=sequence_length,
        max_elements_per_client=max_elements_per_user)

    input_spec = train_dataset_preprocess_comp.type_signature.result.element

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    def client_weight_fn(local_outputs):
        # `num_tokens` is a tensor with dtype int64[1]; to use it as a weight
        # we need a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    iterative_process = iterative_process_builder(
        tff_model_fn, client_weight_fn=client_weight_fn)

    training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
        train_dataset_preprocess_comp, iterative_process)

    training_process.get_model_weights = iterative_process.get_model_weights

    client_datasets_fn = training_utils.build_client_datasets_fn(
        dataset=train_clientdata,
        clients_per_round=clients_per_round,
        random_seed=client_datasets_random_seed)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      test_fn=test_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)
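`compose_dataset_computation_with_iterative_process` fuses the preprocessing computation into the process, so client data is preprocessed inside the federated computation rather than in Python. A sketch of the resulting call pattern; `raw_client_datasets` is a hypothetical placeholder:

state = training_process.initialize()
# `next` now accepts unpreprocessed client data; the composed preprocessing
# computation is applied to each client's dataset before local training.
state, metrics = training_process.next(state, raw_client_datasets)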
Example #17
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        max_batches_per_client: Optional[int] = -1,
        client_datasets_random_seed: Optional[int] = None,
        sequence_length: Optional[int] = 80,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_shakespeare',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        max_eval_batches: Optional[int] = None,
        **kwargs):
    """Runs an iterative process on a Shakespeare next character prediction task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      a `client_weight_fn`, and returns a `tff.templates.IterativeProcess`. The
      `model_fn` must return a `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset is
      used.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    sequence_length: An int specifying the length of the character sequences
      used for prediction.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

    train_clientdata = shakespeare_dataset.construct_character_level_datasets(
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        sequence_length=sequence_length,
        max_batches_per_client=max_batches_per_client)

    _, test_dataset = shakespeare_dataset.get_centralized_datasets(
        train_batch_size=client_batch_size,
        max_test_batches=max_eval_batches,
        sequence_length=sequence_length)

    model_builder = functools.partial(create_shakespeare_model,
                                      sequence_length=sequence_length)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    input_spec = train_clientdata.create_tf_dataset_for_client(
        train_clientdata.client_ids[0]).element_spec

    def client_weight_fn(local_outputs):
        # `num_tokens` is a tensor with dtype int64[1]; to use it as a weight
        # we need a float32 scalar.
        return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    training_process = iterative_process_builder(
        tff_model_fn, client_weight_fn=client_weight_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        dataset=train_clientdata,
        clients_per_round=clients_per_round,
        random_seed=client_datasets_random_seed)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=test_dataset,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    logging.info(model_builder().summary())

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=evaluate_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)
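As in Example #14, `metrics_builder` is referenced but not shown. A hedged sketch for next-character prediction; the real module likely also masks padding tokens:

def metrics_builder():
    # Hypothetical; plain accuracy over predicted characters.
    return [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]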
Example #18
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=FLAGS.only_digits)

    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `MeasuredProcess` for broadcast process and a
        # `MeasuredProcess` for aggregate process by providing the
        # `_broadcast_encoder_fn` and `_mean_encoder_fn` to corresponding utilities.
        # The fns are called once for each of the model weights created by
        # tff_model_fn, and return instances of appropriate encoders.
        encoded_broadcast_process = (
            tff.learning.framework.build_encoded_broadcast_process_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_process = (
            tff.learning.framework.build_encoded_mean_process_from_model(
                tff_model_fn, _mean_encoder_fn))
    else:
        encoded_broadcast_process = None
        encoded_mean_process = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        aggregation_process=encoded_mean_process,
        broadcast_process=encoded_broadcast_process)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
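`_broadcast_encoder_fn` and `_mean_encoder_fn` are referenced but not shown. In the TFF compression tutorials such functions are built from `tensor_encoding` encoders; a sketch along those lines (the quantization settings and size threshold are assumptions):

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

def _broadcast_encoder_fn(value):
    """Quantizes large weights for server-to-client broadcast."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_simple_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def _mean_encoder_fn(value):
    """Quantizes large weight updates for client-to-server aggregation."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_gather_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)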
Example #19
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    evaluation_computation_builder: Callable[..., tff.Computation],
    *,  # Caller passes below args by name.
    client_batch_size: int,
    clients_per_round: int,
    global_variables_only: bool,
    # Begin task-specific parameters.
    split_by_user: bool,
    split_train_fraction: float,
    split_val_fraction: float,
    normalize_ratings: bool,
    max_examples_per_user: Optional[int],
    num_items: int,
    num_latent_factors: int,
    add_biases: bool,
    l2_regularization: float,
    spreadout_lambda: float,
    accuracy_threshold: float,
    dataset_path:
    str = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip',
    # End task-specific parameters.
    total_rounds: int,
    experiment_name: Optional[str] = 'federated_ml_mf',
    root_output_dir: Optional[str] = '/tmp/fed_recon',
    **kwargs):
  """Runs an iterative process on the MovieLens matrix factorization task.

  This method will load and pre-process dataset and construct a model used for
  the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  This algorithm only sends updates for item embeddings every round. User
  embeddings are reconstructed every round based on the latest item embeddings
  and user data.

  We assume that the iterative process has the following functional type
  signatures:

   *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
   *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, a
      `loss_fn`, a `metrics_fn`, and a `client_weight_fn`, and returns a
      `tff.templates.IterativeProcess`. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for
      an example.
    evaluation_computation_builder: A function that accepts a no-arg `model_fn`,
      a loss_fn`, and a `metrics_fn`, and returns a `tff.Computation` for
      federated reconstruction evaluation. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for
      an example.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    global_variables_only: If True, the `ReconstructionModel` contains
      all model variables as global variables. This can be useful for
      baselines involving aggregating all variables.
    split_by_user: Whether to split MovieLens data into train/val/test by user
      ID or by timestamp. If True, `movielens_dataset.split_tf_datasets` is used
      to partition the set of users into disjoint train/val/test sets. If False,
      `movielens_dataset.split_ratings_df` is used to split each user's data
      into train/val/test portions, so that each user shows up in each data
      partition. Setting to False can be useful for comparing with server-side
      training, since server matrix factorization requires test users to have
      been seen before (since otherwise we don't have trained user embeddings
      for these users).
    split_train_fraction: The fraction of data to use for the train set.
    split_val_fraction: The fraction of data to use for the val set.
      `split_train_fraction` and `split_val_fraction` should sum to no more than
      1. 1 - split_train_fraction - split_val_fraction of the data is left for
      test.
    normalize_ratings: Whether to normalize ratings in 1-5 to be in {-1, -0.5,
      0, 0.5, 1} via a linear scaling.
    max_examples_per_user: If not None, limit the number of rating examples for
      each user to this many examples.
    num_items: Number of items in the preferences matrix.
    num_latent_factors: Dimensionality of the learned user/item embeddings used
      to factorize the preferences matrix.
    add_biases: If True, add three bias terms: (1) user-specific bias, (2)
      item-specific bias, and (3) global bias.
    l2_regularization: The constant to use to scale L2 regularization on all
      weights, including the factorized matrices and the (optional) biases. A
      value of 0.0 indicates no regularization.
    spreadout_lambda: Scaling constant for spreadout regularization on item
      embeddings. This ensures that item embeddings are generally spread far
      apart, and that random items have dissimilar embeddings. See
      `models.EmbeddingSpreadoutRegularizer` for details. A value of 0.0
      indicates no regularization.
    accuracy_threshold: Threshold to use to determine whether a prediction is
      considered correct for metrics.
    dataset_path: URL or local path to the MovieLens 1M dataset. If a URL is
      passed, it is expected to be a .zip archive that will be extracted.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    **kwargs: Additional arguments configuring the training loop. For details on
      supported arguments, see `training_loop.py`.
  """

  logging.info('Copying MovieLens data.')
  if tf.io.gfile.exists('/tmp/ml-1m'):
    tf.io.gfile.rmtree('/tmp/ml-1m')

  if dataset_path.startswith('http'):
    movielens_dataset.download_and_extract_data(dataset_path, '/tmp')
  else:
    tf.io.gfile.makedirs('/tmp/ml-1m/')
    tf.io.gfile.copy(
        os.path.join(dataset_path, 'ratings.dat'),
        '/tmp/ml-1m/ratings.dat',
        overwrite=True)
    tf.io.gfile.copy(
        os.path.join(dataset_path, 'movies.dat'),
        '/tmp/ml-1m/movies.dat',
        overwrite=True)
    tf.io.gfile.copy(
        os.path.join(dataset_path, 'users.dat'),
        '/tmp/ml-1m/users.dat',
        overwrite=True)
  logging.info('Finished copying MovieLens data.')

  ratings_df, _, _ = movielens_dataset.load_movielens_data(
      normalize_ratings=normalize_ratings)

  # Split the ratings into training/val/test.
  if split_by_user:
    tf_datasets = movielens_dataset.create_tf_datasets(
        ratings_df=ratings_df,
        personal_model=True,
        batch_size=client_batch_size,
        max_examples_per_user=max_examples_per_user,
        num_local_epochs=1)

    tf_train_datasets, tf_val_datasets, tf_test_datasets = movielens_dataset.split_tf_datasets(
        tf_datasets,
        train_fraction=split_train_fraction,
        val_fraction=split_val_fraction)
  else:
    train_ratings_df, val_ratings_df, test_ratings_df = movielens_dataset.split_ratings_df(
        ratings_df,
        train_fraction=split_train_fraction,
        val_fraction=split_val_fraction)

    tf_train_datasets = movielens_dataset.create_tf_datasets(
        ratings_df=train_ratings_df,
        personal_model=True,
        batch_size=client_batch_size,
        max_examples_per_user=max_examples_per_user,
        num_local_epochs=1)
    tf_val_datasets = movielens_dataset.create_tf_datasets(
        ratings_df=val_ratings_df,
        personal_model=True,
        batch_size=client_batch_size,
        max_examples_per_user=max_examples_per_user,
        num_local_epochs=1)
    tf_test_datasets = movielens_dataset.create_tf_datasets(
        ratings_df=test_ratings_df,
        personal_model=True,
        batch_size=client_batch_size,
        max_examples_per_user=max_examples_per_user,
        num_local_epochs=1)

  model_fn = models.build_reconstruction_model(
      functools.partial(
          models.get_matrix_factorization_model,
          num_users=1,
          num_items=num_items,
          num_latent_factors=num_latent_factors,
          personal_model=True,
          add_biases=add_biases,
          l2_regularization=l2_regularization,
          spreadout_lambda=spreadout_lambda),
      global_variables_only=global_variables_only)

  loss_fn = models.get_loss_fn()
  metrics_fn = models.get_metrics_fn(accuracy_threshold=accuracy_threshold)

  training_process = iterative_process_builder(
      model_fn, loss_fn=loss_fn, metrics_fn=metrics_fn)
  evaluation_computation = evaluation_computation_builder(
      model_fn, loss_fn=loss_fn, metrics_fn=metrics_fn)

  def client_datasets_fn_from_tf_datasets(
      tf_datasets: List[tf.data.Dataset],
      clients_per_round: int,
  ) -> Callable[[int], List[tf.data.Dataset]]:
    """Produces a sampling function for train/val/test from a list of datasets."""
    sample_clients_fn = federated_trainer_utils.build_list_sample_fn(
        list(range(len(tf_datasets))), size=clients_per_round, replace=False)

    def client_datasets_fn(round_num):
      sampled_clients = sample_clients_fn(round_num)
      return [tf_datasets[client_id] for client_id in sampled_clients]

    return client_datasets_fn

  # Create client sampling functions for each of train/val/test.
  train_client_datasets_fn = client_datasets_fn_from_tf_datasets(
      tf_train_datasets, clients_per_round=clients_per_round)
  val_client_datasets_fn = client_datasets_fn_from_tf_datasets(
      tf_val_datasets, clients_per_round=clients_per_round)
  test_client_datasets_fn = client_datasets_fn_from_tf_datasets(
      tf_test_datasets, clients_per_round=clients_per_round)

  # Create final evaluation functions to pass to `training_loop`.
  val_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=evaluation_computation,
      client_datasets_fn=val_client_datasets_fn)
  test_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=evaluation_computation,
      client_datasets_fn=test_client_datasets_fn)
  test_fn = functools.partial(test_fn, round_num=0)

  logging.info('Starting training loop.')
  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=train_client_datasets_fn,
      validation_fn=val_fn,
      test_fn=test_fn,
      total_rounds=total_rounds,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      **kwargs)
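
The round-indexed client sampling above relies on
`federated_trainer_utils.build_list_sample_fn`, which is not shown in this
excerpt. A minimal sketch of what such a helper might look like (the name and
round-number seeding behavior are assumptions):

import numpy as np

def build_list_sample_fn(values, size, replace=False):
  """Returns a function mapping a round number to a sample of `values`."""
  def sample_fn(round_num):
    # Seed with the round number so the same round always yields the same
    # pseudorandom sample of clients.
    rng = np.random.RandomState(round_num)
    return list(rng.choice(values, size=size, replace=replace))
  return sample_fn
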
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    evaluation_computation_builder: Callable[..., tff.Computation],
    client_batch_size: int,
    clients_per_round: int,
    global_variables_only: bool,
    vocab_size: int = 10000,
    num_oov_buckets: int = 1,
    sequence_length: int = 20,
    max_elements_per_user: int = 1000,
    embedding_size: int = 96,
    latent_size: int = 670,
    num_layers: int = 1,
    total_rounds: int = 1500,
    experiment_name: str = 'federated_so_nwp',
    root_output_dir: str = '/tmp/fed_recon',
    split_dataset_strategy: str = federated_trainer_utils
    .SPLIT_STRATEGY_AGGREGATED,
    split_dataset_proportion: int = 2,
    compose_dataset_computation: bool = False,
    **kwargs):
  """Runs an iterative process on the Stack Overflow next word prediction task.

  This method loads and pre-processes the dataset, constructs a model for the
  task, and then uses `iterative_process_builder` to create an iterative
  process that it applies to the task via
  `federated_research.utils.training_loop`.

  This model sends updates only for the embeddings corresponding to the most
  common words. Embeddings for out-of-vocabulary buckets are reconstructed
  on-device at the beginning of each round and discarded at the end of the
  round.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a Python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, a
      `loss_fn`, a `metrics_fn`, and a `client_weight_fn`, and returns a
      `tff.templates.IterativeProcess`. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for
      an example.
    evaluation_computation_builder: A function that accepts a no-arg `model_fn`,
      a `loss_fn`, and a `metrics_fn`, and returns a `tff.Computation` for
      federated reconstruction evaluation. The `model_fn` must return a
      `reconstruction_model.ReconstructionModel`. See `federated_trainer.py` for
      an example.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    global_variables_only: If True, the `ReconstructionModel` contains all model
      variables as global variables. This can be useful for baselines involving
      aggregating all variables.
    vocab_size: Integer dictating the number of most frequent words to use in
      the vocabulary.
    num_oov_buckets: The number of out-of-vocabulary buckets to use.
    sequence_length: The maximum number of words to take for each sequence.
    max_elements_per_user: The maximum number of elements processed for each
      client's dataset.
    embedding_size: The dimension of the word embedding layer.
    latent_size: The dimension of the latent units in the recurrent layers.
    num_layers: The number of stacked recurrent layers to use.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    split_dataset_strategy: The method used to split the data. Must be one of
      `skip`, in which case every `split_dataset_proportion`-th example is used
      for reconstruction, or `aggregated`, in which case the first
      `1 / split_dataset_proportion` fraction of the examples is used for
      reconstruction.
    split_dataset_proportion: Parameter controlling how much of the data is used
      for reconstruction. If `split_dataset_proportion` is n, then 1 / n of the
      data is used for reconstruction.
    compose_dataset_computation: Whether to compose dataset computation with
      training and evaluation computations. If True, may speed up experiments by
      parallelizing dataset computations in multimachine setups. Not currently
      supported in OSS.
    **kwargs: Additional arguments configuring the training loop. For details on
      supported arguments, see `training_loop.py`.
  """

  loss_fn = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_word_prediction.get_special_tokens(
      vocab_size, num_oov_buckets)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_fn():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
    ]

  train_clientdata, validation_clientdata, test_clientdata = (
      tff.simulation.datasets.stackoverflow.load_data())

  vocab = stackoverflow_word_prediction.create_vocab(vocab_size)
  dataset_preprocess_comp = stackoverflow_dataset.create_preprocess_fn(
      vocab=vocab,
      num_oov_buckets=num_oov_buckets,
      client_batch_size=client_batch_size,
      max_sequence_length=sequence_length,
      max_elements_per_client=max_elements_per_user,
      feature_dtypes=train_clientdata.element_type_structure,
      sort_by_date=True)

  input_spec = dataset_preprocess_comp.type_signature.result.element

  model_fn = functools.partial(
      models.create_recurrent_reconstruction_model,
      vocab_size=vocab_size,
      num_oov_buckets=num_oov_buckets,
      embedding_size=embedding_size,
      latent_size=latent_size,
      num_layers=num_layers,
      input_spec=input_spec,
      global_variables_only=global_variables_only)

  def client_weight_fn(local_outputs):
    # `num_tokens` is an int64[1] tensor; to use it as a weight we need a
    # float32 scalar.
    return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  iterative_process = iterative_process_builder(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_weight_fn=client_weight_fn,
      dataset_split_fn_builder=functools.partial(
          federated_trainer_utils.build_dataset_split_fn,
          split_dataset_strategy=split_dataset_strategy,
          split_dataset_proportion=split_dataset_proportion))

  base_eval_computation = evaluation_computation_builder(
      model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      dataset_split_fn_builder=functools.partial(
          federated_trainer_utils.build_dataset_split_fn,
          split_dataset_strategy=split_dataset_strategy,
          split_dataset_proportion=split_dataset_proportion))

  if compose_dataset_computation:
    # Compose dataset computations with client training and evaluation to avoid
    # linear cost of computing centrally. This changes the expected input of
    # the `IterativeProcess` and `tff.Computation` to be a list of client IDs
    # instead of datasets.
    training_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process(
            dataset_preprocess_comp, iterative_process))
    training_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process(
            train_clientdata.dataset_computation, training_process))
    training_process.get_model_weights = iterative_process.get_model_weights

    base_eval_computation = (
        tff.simulation.compose_dataset_computation_with_computation(
            dataset_preprocess_comp, base_eval_computation))
    val_computation = (
        tff.simulation.compose_dataset_computation_with_computation(
            validation_clientdata.dataset_computation, base_eval_computation))
    test_computation = (
        tff.simulation.compose_dataset_computation_with_computation(
            test_clientdata.dataset_computation, base_eval_computation))

    # Create client sampling functions for each of train/val/test.
    # We need to sample client IDs, not datasets, and we do not need to apply
    # `dataset_preprocess_comp` since this is applied as part of the training
    # process and evaluation computation.
    train_client_datasets_fn = federated_trainer_utils.build_list_sample_fn(
        train_clientdata.client_ids, size=clients_per_round, replace=False)
    val_client_datasets_fn = federated_trainer_utils.build_list_sample_fn(
        validation_clientdata.client_ids, size=clients_per_round, replace=False)
    test_client_datasets_fn = federated_trainer_utils.build_list_sample_fn(
        test_clientdata.client_ids, size=clients_per_round, replace=False)
  else:
    training_process = iterative_process
    val_computation = base_eval_computation
    test_computation = base_eval_computation
    # Apply dataset computations.
    train_clientdata = train_clientdata.preprocess(dataset_preprocess_comp)
    validation_clientdata = validation_clientdata.preprocess(
        dataset_preprocess_comp)
    test_clientdata = test_clientdata.preprocess(dataset_preprocess_comp)

    # Create client sampling functions for each of train/val/test.
    train_client_datasets_fn = functools.partial(
        tff.simulation.build_uniform_sampling_fn(train_clientdata.client_ids),
        size=clients_per_round)
    val_client_datasets_fn = functools.partial(
        tff.simulation.build_uniform_sampling_fn(
            validation_clientdata.client_ids),
        size=clients_per_round)
    test_client_datasets_fn = functools.partial(
        tff.simulation.build_uniform_sampling_fn(test_clientdata.client_ids),
        size=clients_per_round)

  # Create final evaluation functions to pass to `training_loop`.
  val_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=val_computation,
      client_datasets_fn=val_client_datasets_fn,
      get_model=training_process.get_model_weights)
  test_fn = federated_trainer_utils.build_eval_fn(
      evaluation_computation=test_computation,
      client_datasets_fn=test_client_datasets_fn,
      get_model=training_process.get_model_weights)
  test_fn = functools.partial(test_fn, round_num=0)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=train_client_datasets_fn,
      validation_fn=val_fn,
      test_fn=test_fn,
      total_rounds=total_rounds,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      **kwargs)
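
The sampling functions built above map a round number to a list of client IDs.
A small usage sketch of that pattern, using `tff.simulation.build_uniform_sampling_fn`
as in the code above, with toy client IDs (illustrative only):

import functools
import tensorflow_federated as tff

sample_fn = functools.partial(
    tff.simulation.build_uniform_sampling_fn(
        ['client_a', 'client_b', 'client_c'], replace=False),
    size=2)
# Each call is keyed by the round number and returns a sample of client IDs.
round_5_clients = list(sample_fn(5))
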
Example #21
def run_experiment():
    """Data preprocessing and experiment execution."""
    emnist_train, _ = emnist_dataset.get_federated_datasets(
        train_client_batch_size=FLAGS.client_batch_size,
        train_client_epochs_per_round=FLAGS.client_epochs_per_round,
        only_digits=False)

    _, emnist_test = emnist_dataset.get_centralized_datasets()

    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    input_spec = example_dataset.element_spec

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)
    validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

    client_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'client')
    server_optimizer_fn = functools.partial(
        utils_impl.create_optimizer_from_flags, 'server')

    def tff_model_fn():
        keras_model = model_builder()
        return tff.learning.from_keras_model(keras_model,
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    if FLAGS.use_compression:
        # We create a `MeasuredProcess` for the broadcast and a
        # `MeasuredProcess` for the aggregation by passing
        # `_broadcast_encoder_fn` and `_mean_encoder_fn` to the corresponding
        # utilities. Each fn is called once per model weight created by
        # `tff_model_fn` and returns an appropriate encoder instance.
        encoded_broadcast_process = (
            tff.learning.framework.build_encoded_broadcast_process_from_model(
                tff_model_fn, _broadcast_encoder_fn))
        encoded_mean_process = (
            tff.learning.framework.build_encoded_mean_process_from_model(
                tff_model_fn, _mean_encoder_fn))
    else:
        encoded_broadcast_process = None
        encoded_mean_process = None

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn,
        aggregation_process=encoded_mean_process,
        broadcast_process=encoded_broadcast_process)

    # Log hyperparameters to CSV
    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                               FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    hparam_file = os.path.join(results_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=validation_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
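
The `_broadcast_encoder_fn` and `_mean_encoder_fn` referenced above are not
shown in this excerpt. A plausible sketch, modeled on the TFF model-compression
tutorial (the 10000-element threshold and 8-bit quantization are assumptions):

import tensorflow as tf
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

def _broadcast_encoder_fn(value):
    """Quantizes large weight tensors for server-to-client broadcast."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_simple_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def _mean_encoder_fn(value):
    """Quantizes large weight tensors for client-to-server aggregation."""
    spec = tf.TensorSpec(value.shape, value.dtype)
    if value.shape.num_elements() > 10000:
        return te.encoders.as_gather_encoder(
            te.encoders.uniform_quantization(bits=8), spec)
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
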
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))
    tff.backends.native.set_local_execution_context(max_fanout=10)

    model_builder = functools.partial(
        stackoverflow_models.create_recurrent_model,
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)

    loss_builder = functools.partial(
        tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

    special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
    pad_token = special_tokens.pad
    oov_tokens = special_tokens.oov
    eos_token = special_tokens.eos

    def metrics_builder():
        return [
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_with_oov',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy_no_oov',
                                                    masked_tokens=[pad_token] +
                                                    oov_tokens),
            # Notice BOS never appears in ground truth.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_no_oov_or_eos',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
            keras_metrics.NumBatchesCounter(),
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        ]

    datasets = stackoverflow_dataset.construct_word_level_datasets(
        FLAGS.vocab_size, FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round, FLAGS.sequence_length,
        FLAGS.max_elements_per_user, FLAGS.num_validation_examples)
    train_dataset, validation_dataset, test_dataset = datasets

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0
    else:

        def client_weight_fn(local_outputs):
            return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=validation_dataset.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.
            clipped_count_budget_allocation,
            expected_clients_per_round=FLAGS.clients_per_round)

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        eval_dataset=validation_dataset,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_centralized_evaluate_fn(
        model_builder=model_builder,
        # Use both val and test for symmetry with other experiments, which
        # evaluate on the entire test set.
        eval_dataset=validation_dataset.concatenate(test_dataset),
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=test_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
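
The two `client_weight_fn` variants above change how client model deltas are
averaged. A toy numpy illustration of the difference (not TFF code; the
numbers are made up):

import numpy as np

deltas = np.array([0.2, -0.1, 0.4])        # per-client model deltas (toy)
num_tokens = np.array([10., 1000., 100.])  # per-client `num_tokens` (toy)

# Token-weighted average: clients with more tokens dominate.
weighted_avg = np.sum(deltas * num_tokens) / np.sum(num_tokens)

# Uniform weighting: every client contributes equally.
uniform_avg = deltas.mean()
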
Example #23
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.client_learning_rate,
        decay_factor=FLAGS.client_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    server_lr_callback = callbacks.create_reduce_lr_on_plateau(
        learning_rate=FLAGS.server_learning_rate,
        decay_factor=FLAGS.server_decay_factor,
        min_delta=FLAGS.min_delta,
        min_lr=FLAGS.min_lr,
        window_size=FLAGS.window_size,
        patience=FLAGS.patience)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return adaptive_fed_avg.build_fed_avg_process(
            model_fn,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=client_optimizer_fn,
            server_optimizer_fn=server_optimizer_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec, crop_size=FLAGS.cifar100_crop_size)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
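
The reduce-on-plateau callbacks above decay a learning rate when a windowed
metric stops improving. A toy sketch of that rule (the function name and exact
update logic are assumptions; the real logic lives in
`callbacks.create_reduce_lr_on_plateau`):

def maybe_decay_lr(lr, window_losses, best_loss, wait,
                   decay_factor=0.5, min_delta=1e-4, min_lr=0.0, patience=10):
    """Returns (new_lr, new_best_loss, new_wait) after one round."""
    window_avg = sum(window_losses) / len(window_losses)
    if window_avg < best_loss - min_delta:
        return lr, window_avg, 0          # metric improved: reset patience
    if wait + 1 >= patience:
        return max(lr * decay_factor, min_lr), best_loss, 0  # plateau: decay
    return lr, best_loss, wait + 1        # keep waiting
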
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_task = 'digit_recognition'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=FLAGS.client_epochs_per_round,
      batch_size=FLAGS.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=FLAGS.only_digits)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model,
        only_digits=FLAGS.only_digits)
  elif FLAGS.model == '1m_cnn':
    model_builder = functools.partial(
        create_1m_cnn_model, only_digits=FLAGS.only_digits)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  logging.info('Training model:')
  model_builder().summary(print_fn=logging.info)

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)

  # Most logic for deciding what baseline to run is here.
  aggregation_factory = fl_utils.build_aggregator(
      compression_flags=compression_dict,
      dp_flags=dp_dict,
      num_clients=len(emnist_train.client_ids),
      num_clients_per_round=FLAGS.clients_per_round,
      num_rounds=FLAGS.total_rounds,
      client_template=model_builder().trainable_variables)

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        loss=loss_builder(),
        input_spec=input_spec,
        metrics=metrics_builder())

  server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('server')
  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=tff_model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=tff.learning.ClientWeighting.UNIFORM,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)
  training_process.get_model_weights = iterative_process.get_model_weights

  client_ids_fn = functools.partial(
      tff.simulation.build_uniform_sampling_fn(
          emnist_train.client_ids,
          replace=False,
          random_seed=FLAGS.client_datasets_random_seed),
      size=FLAGS.clients_per_round)

  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
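
All of these examples drive the same loop. A minimal sketch of the contract
`training_loop.run` appears to honor, inferred from the documented
`initialize`/`next` signatures (an assumption, not the real implementation,
which also writes metrics and checkpoints; note that some examples pass a
`validation_fn` that takes model weights rather than the raw state):

def run(iterative_process, client_datasets_fn, validation_fn,
        total_rounds, test_fn=None, **unused_kwargs):
  state = iterative_process.initialize()
  for round_num in range(total_rounds):
    client_data = client_datasets_fn(round_num)
    state, train_metrics = iterative_process.next(state, client_data)
    val_metrics = validation_fn(state, round_num)
  if test_fn is not None:
    test_metrics = test_fn(state)
  return state
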
Example #25
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')

    client_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'client')
    server_lr_schedule = optimizer_utils.create_lr_schedule_from_flags(
        'server')

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
        client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

    Args:
      model_fn: A no-arg function returning a `tff.learning.Model`.
      client_weight_fn: Optional function that takes the output of
        `model.report_local_outputs` and returns a tensor providing the weight
        in the federated average of model deltas. If not provided, the default
        is the total number of examples processed on device.

    Returns:
      A `tff.templates.IterativeProcess`.
    """

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'cifar100':
        runner_spec = federated_cifar100.configure_training(
            task_spec,
            crop_size=FLAGS.cifar100_crop_size,
            distort_train_images=FLAGS.cifar100_distort_train_images)
    elif FLAGS.task == 'emnist_cr':
        runner_spec = federated_emnist.configure_training(
            task_spec, model=FLAGS.emnist_cr_model)
    elif FLAGS.task == 'emnist_ae':
        runner_spec = federated_emnist_ae.configure_training(task_spec)
    elif FLAGS.task == 'shakespeare':
        runner_spec = federated_shakespeare.configure_training(
            task_spec, sequence_length=FLAGS.shakespeare_sequence_length)
    elif FLAGS.task == 'stackoverflow_nwp':
        runner_spec = federated_stackoverflow.configure_training(
            task_spec,
            vocab_size=FLAGS.so_nwp_vocab_size,
            num_oov_buckets=FLAGS.so_nwp_num_oov_buckets,
            sequence_length=FLAGS.so_nwp_sequence_length,
            max_elements_per_user=FLAGS.so_nwp_max_elements_per_user,
            num_validation_examples=FLAGS.so_nwp_num_validation_examples)
    elif FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(
            task_spec,
            vocab_tokens_size=FLAGS.so_lr_vocab_tokens_size,
            vocab_tags_size=FLAGS.so_lr_vocab_tags_size,
            max_elements_per_user=FLAGS.so_lr_max_elements_per_user,
            num_validation_examples=FLAGS.so_lr_num_validation_examples)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    _write_hparam_flags()

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
                      rounds_per_profile=FLAGS.rounds_per_profile)
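
By convention in these examples, an LR schedule is a callable mapping the
round number to a learning rate. A toy sketch under that assumption
(inverse-time decay; the real schedules come from
`optimizer_utils.create_lr_schedule_from_flags`):

def make_inverse_time_lr_schedule(base_lr=0.1, decay_rate=0.01):
    """Returns a schedule mapping round_num -> learning rate."""
    def schedule(round_num):
        return base_lr / (1.0 + decay_rate * round_num)
    return schedule

client_lr_schedule = make_inverse_time_lr_schedule(base_lr=0.02)
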
Example #26
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    emnist_train, emnist_test = emnist_dataset.get_emnist_datasets(
        FLAGS.client_batch_size,
        FLAGS.client_epochs_per_round,
        only_digits=False)

    if FLAGS.model == 'cnn':
        model_builder = functools.partial(
            emnist_models.create_conv_dropout_model, only_digits=False)
    elif FLAGS.model == '2nn':
        model_builder = functools.partial(
            emnist_models.create_two_hidden_layer_model, only_digits=False)
    else:
        raise ValueError('Cannot handle model flag [{!s}].'.format(
            FLAGS.model))

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    if FLAGS.uniform_weighting:

        def client_weight_fn(local_outputs):
            del local_outputs
            return 1.0

    else:
        client_weight_fn = None  #  Defaults to the number of examples per client.

    def model_fn():
        return tff.learning.from_keras_model(
            model_builder(),
            loss_builder(),
            input_spec=emnist_test.element_spec,
            metrics=metrics_builder())

    if FLAGS.noise_multiplier is not None:
        if not FLAGS.uniform_weighting:
            raise ValueError(
                'Differential privacy is only implemented for uniform weighting.'
            )

        dp_query = tff.utils.build_dp_query(
            clip=FLAGS.clip,
            noise_multiplier=FLAGS.noise_multiplier,
            expected_total_weight=FLAGS.clients_per_round,
            adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
            target_unclipped_quantile=FLAGS.target_unclipped_quantile,
            clipped_count_budget_allocation=FLAGS.
            clipped_count_budget_allocation,
            expected_clients_per_round=FLAGS.clients_per_round,
            per_vector_clipping=FLAGS.per_vector_clipping,
            model=model_fn())

        weights_type = tff.learning.framework.weights_type_from_model(model_fn)
        aggregation_process = tff.utils.build_dp_aggregate_process(
            weights_type.trainable, dp_query)
    else:
        aggregation_process = None

    server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'server')
    client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
        'client')
    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        server_optimizer_fn=server_optimizer_fn,
        client_weight_fn=client_weight_fn,
        client_optimizer_fn=client_optimizer_fn,
        aggregation_process=aggregation_process)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        emnist_train, FLAGS.clients_per_round)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
    training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

    training_loop.run(iterative_process=iterative_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      hparam_dict=hparam_dict,
                      **training_loop_dict)
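
The DP query built above implements, roughly, the Gaussian mechanism over
clipped client updates. A toy numpy illustration of the fixed-clip case (not
TFF code; simplified to one flat vector per client):

import numpy as np

def dp_average(updates, clip=1.0, noise_multiplier=1.1, rng=None):
    """Clips each update to L2 norm `clip`, averages, adds Gaussian noise."""
    rng = rng or np.random.RandomState(0)
    clipped = [u * min(1.0, clip / (np.linalg.norm(u) + 1e-12))
               for u in updates]
    mean = np.mean(clipped, axis=0)
    # Noise stddev on the average is noise_multiplier * clip / num_clients.
    stddev = noise_multiplier * clip / len(updates)
    return mean + rng.normal(0.0, stddev, size=mean.shape)
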
Example #27
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        max_batches_per_client: Optional[int] = -1,
        client_datasets_random_seed: Optional[int] = None,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_emnist_ae',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        max_eval_batches: Optional[int] = None,
        **kwargs):
    """Runs an iterative process on the EMNIST autoencoder task.

  This method loads and pre-processes the dataset, constructs a model for the
  task, and then uses `iterative_process_builder` to create an iterative
  process that it applies to the task via
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a Python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset is
      used.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

    emnist_train, _ = emnist_ae_dataset.get_emnist_datasets(
        client_batch_size=client_batch_size,
        client_epochs_per_round=client_epochs_per_round,
        max_batches_per_client=max_batches_per_client,
        only_digits=False)

    _, emnist_test = emnist_ae_dataset.get_centralized_datasets(
        train_batch_size=client_batch_size,
        max_test_batches=max_eval_batches,
        only_digits=False)

    input_spec = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0]).element_spec

    model_builder = emnist_ae_models.create_autoencoder_model

    loss_builder = functools.partial(tf.keras.losses.MeanSquaredError,
                                     reduction=tf.keras.losses.Reduction.SUM)
    metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    training_process = iterative_process_builder(tff_model_fn)

    client_datasets_fn = training_utils.build_client_datasets_fn(
        train_dataset=emnist_train,
        train_clients_per_round=clients_per_round,
        random_seed=client_datasets_random_seed)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=emnist_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)

    training_loop.run(iterative_process=training_process,
                      client_datasets_fn=client_datasets_fn,
                      validation_fn=evaluate_fn,
                      test_fn=evaluate_fn,
                      total_rounds=total_rounds,
                      experiment_name=experiment_name,
                      root_output_dir=root_output_dir,
                      **kwargs)
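
For concreteness, a minimal `iterative_process_builder` satisfying the
contract documented above could wrap plain federated averaging (a sketch; the
optimizers are arbitrary placeholders):

import tensorflow as tf
import tensorflow_federated as tff

def iterative_process_builder(tff_model_fn):
    # FedAvg server state carries a `model` attribute of type
    # `tff.learning.ModelWeights`, as the docstring above requires.
    return tff.learning.build_federated_averaging_process(
        model_fn=tff_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
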
Example #28
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_train, _ = emnist_dataset.get_federated_datasets(
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      only_digits=False)

  _, emnist_test = emnist_dataset.get_centralized_datasets()

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  if FLAGS.uniform_weighting:
    client_weighting = tff.learning.ClientWeighting.UNIFORM
  else:
    client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=emnist_test.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    if FLAGS.noise_multiplier <= 0:
      raise ValueError('noise_multiplier must be positive if DP is enabled.')
    if FLAGS.clip is None or FLAGS.clip <= 0:
      raise ValueError('clip must be positive if DP is enabled.')

    if not FLAGS.adaptive_clip_learning_rate:
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          clip=FLAGS.clip)
    else:
      if FLAGS.adaptive_clip_learning_rate <= 0:
        raise ValueError('adaptive_clip_learning_rate must be positive if '
                         'adaptive clipping is enabled.')
      aggregation_factory = tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
          noise_multiplier=FLAGS.noise_multiplier,
          clients_per_round=FLAGS.clients_per_round,
          initial_l2_norm_clip=FLAGS.clip,
          target_unclipped_quantile=FLAGS.target_unclipped_quantile,
          learning_rate=FLAGS.adaptive_clip_learning_rate)
  else:
    if FLAGS.uniform_weighting:
      aggregation_factory = tff.aggregators.UnweightedMeanFactory()
    else:
      aggregation_factory = tff.aggregators.MeanFactory()

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=client_weighting,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      emnist_train, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)
  validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

  logging.info('Training model:')
  model_builder().summary(print_fn=logging.info)

  # Log hyperparameters to CSV
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_profile=FLAGS.rounds_per_profile)
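
The adaptive-clipping aggregator above tracks a target quantile of client
update norms. A toy sketch of the geometric clip update from Andrew et al.,
"Differentially Private Learning with Adaptive Clipping" (illustrative, not
the TFF implementation):

import numpy as np

def update_clip(clip, frac_unclipped, target_unclipped_quantile=0.5,
                learning_rate=0.2):
  """Moves the clip norm geometrically toward the target quantile.

  `frac_unclipped` is the fraction of clients whose update norm was at or
  below the current clip: too many unclipped updates shrink the clip, too
  few grow it.
  """
  return clip * np.exp(
      -learning_rate * (frac_unclipped - target_unclipped_quantile))
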
def run_federated(
        iterative_process_builder: Callable[...,
                                            tff.templates.IterativeProcess],
        client_epochs_per_round: int,
        client_batch_size: int,
        clients_per_round: int,
        schedule: Optional[str] = 'none',
        beta: Optional[float] = 0.,
        max_batches_per_client: Optional[int] = -1,
        client_datasets_random_seed: Optional[int] = None,
        crop_size: Optional[int] = 24,
        total_rounds: Optional[int] = 1500,
        experiment_name: Optional[str] = 'federated_cifar100',
        root_output_dir: Optional[str] = '/tmp/fed_opt',
        max_eval_batches: Optional[int] = None,
        **kwargs):
    """Runs an iterative process on the CIFAR-100 classification task.

  This method loads and pre-processes the dataset, constructs a model for the
  task, and then uses `iterative_process_builder` to create an iterative
  process that it applies to the task via
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a Python `Mapping` object.

  Moreover, the server state must have an attribute `model` of type
  `tff.learning.ModelWeights`.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
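    schedule: The client selection strategy. `none` samples clients uniformly
      via `training_utils.build_client_datasets_fn` and runs `training_loop`;
      `loss` selects clients by loss via `training_loop_loss` (and requires
      `loss_pool_size` in `hparam_dict`); any other value uses
      availability-based sampling via `training_loop_importance`.
    beta: A float forwarded to the availability-based client sampling; unused
      by the other schedules.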
    max_batches_per_client: An optional int specifying the number of batches
      taken by each client at each round. If `-1`, the entire client dataset is
      used.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    crop_size: An optional integer representing the resulting size of input
      images after preprocessing.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    max_eval_batches: If set to a positive integer, evaluation datasets are
      capped to at most that many batches. If set to None or a nonpositive
      integer, the full evaluation datasets are used.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

    crop_shape = (crop_size, crop_size, 3)

    cifar_train, _, fed_test_data = cifar100_dataset.get_federated_cifar100(
        client_epochs_per_round=client_epochs_per_round,
        train_batch_size=client_batch_size,
        crop_shape=crop_shape,
        max_batches_per_client=max_batches_per_client)

    _, cifar_test = cifar100_dataset.get_centralized_datasets(
        train_batch_size=client_batch_size,
        max_test_batches=max_eval_batches,
        crop_shape=crop_shape)

    input_spec = cifar_train.create_tf_dataset_for_client(
        cifar_train.client_ids[0]).element_spec

    model_builder = functools.partial(resnet_models.create_resnet18,
                                      input_shape=crop_shape,
                                      num_classes=NUM_CLASSES)

    loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
    metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

    def tff_model_fn() -> tff.learning.Model:
        return tff.learning.from_keras_model(keras_model=model_builder(),
                                             input_spec=input_spec,
                                             loss=loss_builder(),
                                             metrics=metrics_builder())

    training_process = iterative_process_builder(tff_model_fn)

    evaluate_fn = training_utils.build_evaluate_fn(
        eval_dataset=cifar_test,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    test_fn = training_utils.build_unweighted_test_fn(
        federated_eval_dataset=fed_test_data,
        model_builder=model_builder,
        loss_builder=loss_builder,
        metrics_builder=metrics_builder)

    logging.info('Training model:')
    model_builder().summary(print_fn=logging.info)
    try:
        var = kwargs['hparam_dict']['var_q_clients']
        q_client = np.load(
            f'/home/monica/AVAIL_VECTORS/q_client_{var}_cifar.npy')
    except (KeyError, OSError):  # Missing hparam or availability file.
        logging.info(
            'Could not load q_client - initializing random availabilities')
        q_client = None

    if schedule == 'none':
        client_datasets_fn = training_utils.build_client_datasets_fn(
            train_dataset=cifar_train,
            train_clients_per_round=clients_per_round,
            random_seed=client_datasets_random_seed,
            var_q_clients=kwargs['hparam_dict']['var_q_clients'],
            f_mult=kwargs['hparam_dict']['f_mult'],
            f_intercept=kwargs['hparam_dict']['f_intercept'],
            sine_wave=kwargs['hparam_dict']['sine_wave'],
            use_p=True,
            q_client=q_client,
        )

        training_loop.run(iterative_process=training_process,
                          client_datasets_fn=client_datasets_fn,
                          validation_fn=evaluate_fn,
                          test_fn=test_fn,
                          total_rounds=total_rounds,
                          experiment_name=experiment_name,
                          root_output_dir=root_output_dir,
                          **kwargs)
    elif schedule == 'loss':
        loss_pool_size = kwargs['hparam_dict'].get('loss_pool_size')
        if loss_pool_size is not None:
            logging.info(f'Loss pool size: {loss_pool_size}')
            client_datasets_fn = training_utils.build_client_datasets_fn(
                train_dataset=cifar_train,
                train_clients_per_round=loss_pool_size,
                random_seed=client_datasets_random_seed,
                var_q_clients=kwargs['hparam_dict']['var_q_clients'],
                f_mult=kwargs['hparam_dict']['f_mult'],
                f_intercept=kwargs['hparam_dict']['f_intercept'],
                sine_wave=kwargs['hparam_dict']['sine_wave'],
                use_p=True,
                q_client=q_client,
            )
            training_loop_loss.run(iterative_process=training_process,
                                   client_datasets_fn=client_datasets_fn,
                                   validation_fn=evaluate_fn,
                                   test_fn=test_fn,
                                   total_rounds=total_rounds,
                                   total_clients=loss_pool_size,
                                   experiment_name=experiment_name,
                                   root_output_dir=root_output_dir,
                                   **kwargs)
        else:
            raise ValueError('Loss pool size not specified')
    else:
        client_datasets_fn = training_utils.build_availability_client_datasets_fn(
            train_dataset=cifar_train,
            train_clients_per_round=clients_per_round,
            random_seed=client_datasets_random_seed,
            beta=beta,
            var_q_clients=kwargs['hparam_dict']['var_q_clients'],
            f_mult=kwargs['hparam_dict']['f_mult'],
            f_intercept=kwargs['hparam_dict']['f_intercept'],
            sine_wave=kwargs['hparam_dict']['sine_wave'],
            q_client=q_client,
        )
        training_loop_importance.run(iterative_process=training_process,
                                     client_datasets_fn=client_datasets_fn,
                                     validation_fn=evaluate_fn,
                                     test_fn=test_fn,
                                     total_rounds=total_rounds,
                                     experiment_name=experiment_name,
                                     root_output_dir=root_output_dir,
                                     **kwargs)
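

A minimal sketch (not from the original source) of an `iterative_process_builder` compatible with the CIFAR-100 runner above, assuming plain TFF federated averaging; the function name and the optimizer learning rates below are illustrative placeholders, not values from the original experiments.

import tensorflow as tf
import tensorflow_federated as tff


def example_iterative_process_builder(model_fn):
    # Hypothetical builder: wraps the no-arg `model_fn` in a vanilla FedAvg
    # process. Learning rates are placeholders, not tuned values from the
    # source experiments.
    return tff.learning.build_federated_averaging_process(
        model_fn=model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
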
def run_federated(
    iterative_process_builder: Callable[..., tff.templates.IterativeProcess],
    client_epochs_per_round: int,
    client_batch_size: int,
    clients_per_round: int,
    client_datasets_random_seed: Optional[int] = None,
    model: Optional[str] = 'cnn',
    total_rounds: Optional[int] = 1500,
    experiment_name: Optional[str] = 'federated_emnist_cr',
    root_output_dir: Optional[str] = '/tmp/fed_opt',
    **kwargs):
  """Runs an iterative process on the EMNIST character recognition task.

  This method will load and pre-process the dataset and construct a model used
  for the task. It then uses `iterative_process_builder` to create an iterative
  process that it applies to the task, using
  `federated_research.utils.training_loop`.

  We assume that the iterative process has the following functional type
  signatures:

    *   `initialize`: `( -> S@SERVER)` where `S` represents the server state.
    *   `next`: `<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>` where `S`
        represents the server state, `{B*}` represents the client datasets,
        and `T` represents a python `Mapping` object.

  The iterative process must also have a callable attribute `get_model_weights`
  that takes as input the state of the iterative process, and returns a
  `tff.learning.ModelWeights` object.

  Args:
    iterative_process_builder: A function that accepts a no-arg `model_fn`, and
      returns a `tff.templates.IterativeProcess`. The `model_fn` must return a
      `tff.learning.Model`.
    client_epochs_per_round: An integer representing the number of epochs of
      training performed per client in each training round.
    client_batch_size: An integer representing the batch size used on clients.
    clients_per_round: An integer representing the number of clients
      participating in each round.
    client_datasets_random_seed: An optional int used to seed which clients are
      sampled at each round. If `None`, no seed is used.
    model: A string specifying the model used for character recognition. Must
      be one of `cnn` or `2nn`, corresponding to a CNN model and a densely
      connected 2-layer model, respectively.
    total_rounds: The number of federated training rounds.
    experiment_name: The name of the experiment being run. This will be appended
      to the `root_output_dir` for purposes of writing outputs.
    root_output_dir: The name of the root output directory for writing
      experiment outputs.
    **kwargs: Additional arguments configuring the training loop. For details
      on supported arguments, see
      `federated_research/utils/training_utils.py`.
  """

  emnist_train, _ = emnist_dataset.get_federated_datasets(
      train_client_batch_size=client_batch_size,
      train_client_epochs_per_round=client_epochs_per_round,
      only_digits=False)

  _, emnist_test = emnist_dataset.get_centralized_datasets(only_digits=False)

  input_spec = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0]).element_spec

  if model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=False)
  elif model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model, only_digits=False)
  else:
    raise ValueError(
        'Cannot handle model flag [{!s}], must be one of {!s}.'.format(
            model, EMNIST_MODELS))

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  training_process = iterative_process_builder(tff_model_fn)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      dataset=emnist_train,
      clients_per_round=clients_per_round,
      random_seed=client_datasets_random_seed)

  test_fn = training_utils.build_centralized_evaluate_fn(
      eval_dataset=emnist_test,
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  validation_fn = lambda model_weights, round_num: test_fn(model_weights)

  logging.info('Training model:')
  # As above: route the Keras summary through the logger rather than logging
  # the None returned by `summary()`.
  model_builder().summary(print_fn=logging.info)

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=total_rounds,
      experiment_name=experiment_name,
      root_output_dir=root_output_dir,
      **kwargs)
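

For completeness, a hedged usage sketch wiring the EMNIST runner above to the hypothetical `example_iterative_process_builder` sketched earlier; every argument value here is illustrative rather than taken from the original experiments.

if __name__ == '__main__':
  run_federated(
      iterative_process_builder=example_iterative_process_builder,
      client_epochs_per_round=1,
      client_batch_size=20,
      clients_per_round=10,
      client_datasets_random_seed=1,
      model='cnn',
      total_rounds=100,
      experiment_name='emnist_fedavg_example',
      root_output_dir='/tmp/fed_opt')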