def test_remove_unused_flags_with_empty_optimizer(self):
    """An empty `optimizer` value (with no prefix) must raise ValueError."""
    hparams = collections.OrderedDict([('optimizer', '')])
    expected_message = ('The flag optimizer was not set. '
                        'Unable to determine the relevant optimizer.')
    with self.assertRaisesRegex(ValueError, expected_message):
        utils_impl.remove_unused_flags(prefix=None, hparam_dict=hparams)
 def test_remove_flags_with_optimizers_sharing_a_prefix(self):
     """Only the selected optimizer's flags survive, even when names overlap.

     `adam` is a prefix of the selected `adamW`, yet `client_adam_momentum`
     is expected to be dropped.
     """
     hparams = collections.OrderedDict([
         ('client_optimizer', 'adamW'),
         ('client_adam_momentum', 0.3),
         ('client_adamW_momentum', 0.5),
     ])
     filtered = utils_impl.remove_unused_flags('client', hparams)
     expected = collections.OrderedDict([
         ('client_optimizer', 'adamW'),
         ('client_adamW_momentum', 0.5),
     ])
     self.assertCountEqual(filtered.keys(), list(expected))
     for flag_name, flag_value in expected.items():
         self.assertEqual(filtered[flag_name], flag_value)
 def test_remove_unused_flags_without_prefix(self):
     """With `prefix=None`, flags of the non-selected optimizer are dropped.

     `optimizer`, `value` and `sgd_momentum` must be kept; `adam_momentum`
     must be removed because the selected optimizer is `sgd`.
     """
     hparams = collections.OrderedDict([
         ('optimizer', 'sgd'),
         ('value', 0.1),
         ('sgd_momentum', 0.3),
         ('adam_momentum', 0.5),
     ])
     filtered = utils_impl.remove_unused_flags(prefix=None, hparam_dict=hparams)
     expected = collections.OrderedDict([
         ('optimizer', 'sgd'),
         ('value', 0.1),
         ('sgd_momentum', 0.3),
     ])
     self.assertCountEqual(filtered.keys(), list(expected))
     for flag_name, flag_value in expected.items():
         self.assertEqual(filtered[flag_name], flag_value)
# Example No. 4
def _build_server_optimizer_fn():
  """Selects the server optimizer factory according to `FLAGS.server_optimizer`.

  Only a `functools.partial` is built here; no optimizer is instantiated, so
  calling this function has no side effects beyond reading FLAGS.

  Returns:
    A `functools.partial` over `tf.keras.optimizers.SGD` (for 'sgd') or
    `flars_optimizer.FLARSOptimizer` (for 'flars'), with `learning_rate` and
    `momentum` taken from `FLAGS.server_learning_rate` /
    `FLAGS.server_momentum` (and `max_ratio` from `FLAGS.max_ratio` for
    'flars').

  Raises:
    ValueError: If `FLAGS.server_optimizer` is neither 'sgd' nor 'flars'.
  """
  if FLAGS.server_optimizer == 'sgd':
    return functools.partial(
        tf.keras.optimizers.SGD,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum)
  if FLAGS.server_optimizer == 'flars':
    return functools.partial(
        flars_optimizer.FLARSOptimizer,
        learning_rate=FLAGS.server_learning_rate,
        momentum=FLAGS.server_momentum,
        max_ratio=FLAGS.max_ratio)
  raise ValueError('Optimizer %s is not supported.' % FLAGS.server_optimizer)


def _run_experiment():
  """Data preprocessing and experiment execution.

  The function performs these steps in order:
  - Loads EMNIST via `tff.simulation.datasets.emnist.load_data` with
    `only_digits=FLAGS.digit_only_emnist`.
  - Builds the train/test preprocessing pipelines, the TFF model factory,
    per-round client sampling and centralized evaluation.
  - Creates `FLAGS.root_output_dir` and a `_MetricsHook` seeded with the
    relevant hyperparameter flags.
  - Runs `_federated_averaging_training_loop` for `FLAGS.total_rounds` rounds.

  Raises:
    ValueError: If `FLAGS.server_optimizer` is not supported. This is checked
      first, before any dataset loading or filesystem side effects.
  """
  # Validate the server optimizer flag up front, so a typo fails fast instead
  # of after the dataset load, directory creation and metrics-hook setup.
  server_optimizer_fn = _build_server_optimizer_fn()

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
      only_digits=FLAGS.digit_only_emnist)

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    # Shuffle within a 10k buffer, repeat once per local epoch, then batch.
    return (dataset.map(reshape_emnist_element).shuffle(
        buffer_size=10000).repeat(FLAGS.client_epochs_per_round).batch(
            FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(reshape_emnist_element).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  # Evaluation is centralized: all test clients are merged into one dataset.
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # The element spec of one preprocessed client dataset is used as the TFF
  # model's input spec.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  input_spec = example_dataset.element_spec

  def tff_model_fn():
    """Builds a fresh Keras model and wraps it as a TFF learning model."""
    keras_model = model_builder()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # random.sample draws without replacement; it raises ValueError if
    # FLAGS.train_clients_per_round exceeds the number of clients.
    sampled_clients = random.sample(
        population=emnist_train.client_ids, k=FLAGS.train_clients_per_round)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  def evaluate_fn(state):
    """Evaluates the server model in `state` on the merged test dataset."""
    compiled_keras_model = compiled_eval_keras_model()
    state.model.assign_weights_to(compiled_keras_model)
    eval_metrics = compiled_keras_model.evaluate(emnist_test, verbose=0)
    # NOTE(review): assumes the eval model is compiled with sparse categorical
    # accuracy as its first (and only reported) metric — confirm.
    return {
        'loss': eval_metrics[0],
        'sparse_categorical_accuracy': eval_metrics[1],
    }

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])
  # Drop client optimizer flags that do not belong to the selected optimizer.
  hparam_dict = utils_impl.remove_unused_flags('client', hparam_dict)

  metrics_hook = _MetricsHook(FLAGS.exp_name, FLAGS.root_output_dir,
                              hparam_dict)

  def client_optimizer_fn():
    """Creates the client optimizer described by the `client_*` flags."""
    return utils_impl.create_optimizer_from_flags('client')

  _federated_averaging_training_loop(
      model_fn=tff_model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
 def test_remove_unused_flags_without_optimizer_flag(self):
     """Missing `client_optimizer` (only `client_opt_fn` given) raises ValueError."""
     hparams = collections.OrderedDict([
         ('client_opt_fn', 'sgd'),
         ('client_sgd_momentum', 0.3),
     ])
     expected_message = 'The flag client_optimizer was not defined.'
     with self.assertRaisesRegex(ValueError, expected_message):
         utils_impl.remove_unused_flags('client', hparams)