def test_create_client_optimizer_from_flags(self, optimizer_name,
                                            optimizer_cls):
    with flag_sandbox(
        {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): optimizer_name}):
        # Construct a default optimizer from the flag values alone.
        default_optimizer = utils_impl.create_optimizer_from_flags(
            TEST_CLIENT_FLAG_PREFIX)
        self.assertIsInstance(default_optimizer, optimizer_cls)
        # Override the flag-provided learning rate via the `overrides` argument.
        overridden_learning_rate = 5.0
        custom_optimizer = utils_impl.create_optimizer_from_flags(
            TEST_CLIENT_FLAG_PREFIX,
            overrides={'learning_rate': overridden_learning_rate})
        self.assertIsInstance(custom_optimizer, optimizer_cls)
        self.assertEqual(custom_optimizer.get_config()['learning_rate'],
                         overridden_learning_rate)
        # Override the learning rate through the command-line flag itself.
        commandline_set_learning_rate = 100.0
        with flag_sandbox({
                '{}_learning_rate'.format(TEST_CLIENT_FLAG_PREFIX):
                    commandline_set_learning_rate
        }):
            custom_optimizer = utils_impl.create_optimizer_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
            self.assertIsInstance(custom_optimizer, optimizer_cls)
            self.assertEqual(
                custom_optimizer.get_config()['learning_rate'],
                commandline_set_learning_rate)
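
# `flag_sandbox` is defined outside this snippet. A minimal sketch of such a
# context manager, assuming it only swaps absl flag values in and restores the
# originals on exit (the behavior is inferred from the tests above, not a
# published API):
import contextlib

from absl import flags

FLAGS = flags.FLAGS


@contextlib.contextmanager
def flag_sandbox(flag_value_dict):
    """Temporarily sets absl FLAGS values, restoring the originals on exit."""
    old_values = {name: FLAGS[name].value for name in flag_value_dict}
    for name, value in flag_value_dict.items():
        FLAGS[name].value = value
    try:
        yield
    finally:
        for name, value in old_values.items():
            FLAGS[name].value = value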
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Expected no command-line arguments, '
                             'got: {}'.format(argv))

    client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'client')
    server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags(
        'server')

    compression_dict = utils_impl.lookup_flag_values(compression_flags)
    dp_dict = utils_impl.lookup_flag_values(dp_flags)

    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model],
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`."""

        model_trainable_variables = model_fn().trainable_variables

        # Most logic for deciding what to run is here.
        aggregation_factory = fl_utils.build_aggregator(
            compression_flags=compression_dict,
            dp_flags=dp_dict,
            num_clients=get_total_num_clients(FLAGS.task),
            num_clients_per_round=FLAGS.clients_per_round,
            num_rounds=FLAGS.total_rounds,
            client_template=model_trainable_variables)

        return tff.learning.build_federated_averaging_process(
            model_fn=model_fn,
            server_optimizer_fn=server_optimizer_fn,
            client_weighting=tff.learning.ClientWeighting.UNIFORM,
            client_optimizer_fn=client_optimizer_fn,
            model_update_aggregation_factory=aggregation_factory)

    task_spec = training_specs.TaskSpec(
        iterative_process_builder=iterative_process_builder,
        client_epochs_per_round=FLAGS.client_epochs_per_round,
        client_batch_size=FLAGS.client_batch_size,
        clients_per_round=FLAGS.clients_per_round,
        client_datasets_random_seed=FLAGS.client_datasets_random_seed)

    if FLAGS.task == 'stackoverflow_lr':
        runner_spec = federated_stackoverflow_lr.configure_training(task_spec)
    else:
        raise ValueError(
            '--task flag {} is not supported, must be one of {}.'.format(
                FLAGS.task, _SUPPORTED_TASKS))

    training_loop.run(iterative_process=runner_spec.iterative_process,
                      client_datasets_fn=runner_spec.client_datasets_fn,
                      validation_fn=runner_spec.validation_fn,
                      test_fn=runner_spec.test_fn,
                      total_rounds=FLAGS.total_rounds,
                      experiment_name=FLAGS.experiment_name,
                      root_output_dir=FLAGS.root_output_dir,
                      rounds_per_eval=FLAGS.rounds_per_eval,
                      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
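
# This `main` assumes the optimizer flags were registered at import time. A
# minimal sketch of that wiring, assuming `utils_impl.define_optimizer_flags`
# is the companion helper that defines `--<prefix>_optimizer`,
# `--<prefix>_learning_rate`, and the per-optimizer flags; the two extra
# flag definitions shown here are illustrative, not the script's full list:
from absl import app
from absl import flags

FLAGS = flags.FLAGS

utils_impl.define_optimizer_flags('client')
utils_impl.define_optimizer_flags('server')
flags.DEFINE_integer('total_rounds', 200, 'Number of federated rounds.')
flags.DEFINE_integer('clients_per_round', 10, 'Clients sampled per round.')

if __name__ == '__main__':
    app.run(main)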
Example #3
def test_create_optimizer_from_flags_invalid_overrides(self):
    with flag_sandbox(
        {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        with self.assertRaisesRegex(TypeError,
                                    'type `collections.Mapping`'):
            _ = utils_impl.create_optimizer_from_flags(
                TEST_CLIENT_FLAG_PREFIX, overrides=[1, 2, 3])
Example #4
def test_create_optimizer_from_flags_flags_set_not_for_optimizer(self):
    with flag_sandbox(
        {'{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX): 'sgd'}):
        # Set an Adam flag that isn't used in SGD.
        # We need to use `_parse_args` because that is the only way FLAGS is
        # notified that a non-default value is being used.
        bad_adam_flag = '{}_adam_beta_1'.format(TEST_CLIENT_FLAG_PREFIX)
        FLAGS._parse_args(args=['--{}=0.5'.format(bad_adam_flag)],
                          known_only=True)
        with self.assertRaisesRegex(
                ValueError,
                r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'
        ):
            _ = utils_impl.create_optimizer_from_flags(
                TEST_CLIENT_FLAG_PREFIX)
        FLAGS[bad_adam_flag].unparse()
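
# The `_parse_args`/`unparse` pairing above can be wrapped in a helper so the
# flag is always restored even if the body raises. A sketch, assuming the same
# absl FLAGS object (this helper is illustrative and not part of utils_impl):
import contextlib


@contextlib.contextmanager
def explicitly_parsed_flag(name, value):
    """Marks `name` as explicitly set on the command line, undoing it on exit."""
    FLAGS._parse_args(args=['--{}={}'.format(name, value)], known_only=True)
    try:
        yield
    finally:
        FLAGS[name].unparse()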
Example #5
def _run_experiment():
  """Data preprocessing and experiment execution."""
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=FLAGS.digit_only_emnist)

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    return (dataset.map(reshape_emnist_element).shuffle(
        buffer_size=10000).repeat(FLAGS.client_epochs_per_round).batch(
            FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(reshape_emnist_element).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  input_spec = example_dataset.element_spec

  def tff_model_fn():
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    sampled_clients = random.sample(
        population=emnist_train.client_ids, k=FLAGS.train_clients_per_round)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  def evaluate_fn(state):
    compiled_keras_model = compiled_eval_keras_model()
    state.model.assign_weights_to(compiled_keras_model)
    eval_metrics = compiled_keras_model.evaluate(emnist_test, verbose=0)
    return {
        'loss': eval_metrics[0],
        'sparse_categorical_accuracy': eval_metrics[1],
    }

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  hparam_dict = utils_impl.remove_unused_flags('client', hparam_dict)

  metrics_hook = _MetricsHook(FLAGS.exp_name, FLAGS.root_output_dir,
                              hparam_dict)

  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')

  if FLAGS.server_optimizer == 'sgd':
    server_optimizer_fn = functools.partial(
        tf.keras.optimizers.SGD,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum)
  elif FLAGS.server_optimizer == 'flars':
    server_optimizer_fn = functools.partial(
        flars_optimizer.FLARSOptimizer,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum,
        max_ratio=FLAGS.max_ratio)
  else:
    raise ValueError('Optimizer %s is not supported.' % FLAGS.server_optimizer)

  _federated_averaging_training_loop(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
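
# `_MetricsHook` is defined elsewhere in this script. A minimal sketch of a
# compatible hook, assuming the training loop invokes it as
# `metrics_hook(metrics, round_num)` (the call signature and file layout here
# are assumptions for illustration):
import collections
import os

import pandas as pd
import tensorflow as tf


class _MetricsHook(object):
    """Writes per-round metrics to a CSV under the experiment directory."""

    def __init__(self, exp_name, root_output_dir, hparam_dict):
        results_dir = os.path.join(root_output_dir, exp_name)
        tf.io.gfile.makedirs(results_dir)
        # Record the hyperparameters alongside the metrics.
        with tf.io.gfile.GFile(os.path.join(results_dir, 'hparams.csv'),
                               'w') as f:
            pd.DataFrame(hparam_dict, index=[0]).to_csv(f, index=False)
        self._results_path = os.path.join(results_dir, 'results.csv')
        self._rows = []

    def __call__(self, metrics, round_num):
        # Rewrite the full CSV each round so partial runs remain readable.
        self._rows.append(dict(collections.OrderedDict(metrics),
                               round_num=round_num))
        with tf.io.gfile.GFile(self._results_path, 'w') as f:
            pd.DataFrame(self._rows).to_csv(f, index=False)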
Example #6
def test_create_optimizer_from_flags_invalid_optimizer(self):
    FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'foo'
    with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
        _ = utils_impl.create_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
Example #7
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  emnist_task = 'digit_recognition'
  emnist_train, _ = tff.simulation.datasets.emnist.load_data(only_digits=False)
  _, emnist_test = emnist_dataset.get_centralized_datasets(
      only_digits=False, emnist_task=emnist_task)

  train_preprocess_fn = emnist_dataset.create_preprocess_fn(
      num_epochs=FLAGS.client_epochs_per_round,
      batch_size=FLAGS.client_batch_size,
      emnist_task=emnist_task)

  input_spec = train_preprocess_fn.type_signature.result.element

  if FLAGS.model == 'cnn':
    model_builder = functools.partial(
        emnist_models.create_conv_dropout_model, only_digits=FLAGS.only_digits)
  elif FLAGS.model == '2nn':
    model_builder = functools.partial(
        emnist_models.create_two_hidden_layer_model,
        only_digits=FLAGS.only_digits)
  elif FLAGS.model == '1m_cnn':
    model_builder = functools.partial(
        create_1m_cnn_model, only_digits=FLAGS.only_digits)
  else:
    raise ValueError('Cannot handle model flag [{!s}].'.format(FLAGS.model))

  logging.info('Training model:')
  logging.info(model_builder().summary())

  loss_builder = tf.keras.losses.SparseCategoricalCrossentropy
  metrics_builder = lambda: [tf.keras.metrics.SparseCategoricalAccuracy()]

  compression_dict = utils_impl.lookup_flag_values(compression_flags)
  dp_dict = utils_impl.lookup_flag_values(dp_flags)

  # Most logic for deciding what baseline to run is here.
  aggregation_factory = fl_utils.build_aggregator(
      compression_flags=compression_dict,
      dp_flags=dp_dict,
      num_clients=len(emnist_train.client_ids),
      num_clients_per_round=FLAGS.clients_per_round,
      num_rounds=FLAGS.total_rounds,
      client_template=model_builder().trainable_variables)

  def tff_model_fn():
    return tff.learning.from_keras_model(
        keras_model=model_builder(),
        loss=loss_builder(),
        input_spec=input_spec,
        metrics=metrics_builder())

  server_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('server')
  client_optimizer_fn = lambda: utils_impl.create_optimizer_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=tff_model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=tff.learning.ClientWeighting.UNIFORM,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  @tff.tf_computation(tf.string)
  def build_train_dataset_from_client_id(client_id):
    client_dataset = emnist_train.dataset_computation(client_id)
    return train_preprocess_fn(client_dataset)

  training_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      build_train_dataset_from_client_id, iterative_process)
  training_process.get_model_weights = iterative_process.get_model_weights

  client_ids_fn = functools.partial(
      tff.simulation.build_uniform_sampling_fn(
          emnist_train.client_ids,
          replace=False,
          random_seed=FLAGS.client_datasets_random_seed),
      size=FLAGS.clients_per_round)

  # We convert the output to a list (instead of an np.ndarray) so that it can
  # be used as input to the iterative process.
  client_sampling_fn = lambda x: list(client_ids_fn(x))

  evaluate_fn = tff.learning.build_federated_evaluation(tff_model_fn)

  def test_fn(state):
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  def validation_fn(state, round_num):
    del round_num
    return evaluate_fn(
        iterative_process.get_model_weights(state), [emnist_test])

  training_loop.run(
      iterative_process=training_process,
      client_datasets_fn=client_sampling_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)
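
# A sketch of the core of what `training_loop.run` does with the pieces above,
# assuming the standard `initialize`/`next` protocol of a
# tff.templates.IterativeProcess (the real loop also handles checkpointing,
# evaluation cadence, and metrics export):
state = training_process.initialize()
for round_num in range(FLAGS.total_rounds):
    # Because of `compose_dataset_computation_with_iterative_process`, `next`
    # takes raw client IDs rather than materialized client datasets.
    sampled_ids = client_sampling_fn(round_num)
    state, metrics = training_process.next(state, sampled_ids)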