Ejemplo n.º 1
0
  def test_remove_unused_flags_with_empty_optimizer(self):
    """Expects a ValueError when the 'optimizer' flag is an empty string."""
    flags_with_blank_optimizer = collections.OrderedDict()
    flags_with_blank_optimizer['optimizer'] = ''
    expected_message = ('The flag optimizer was not set. '
                        'Unable to determine the relevant optimizer.')

    with self.assertRaisesRegex(ValueError, expected_message):
      utils_impl.remove_unused_flags(
          prefix=None, hparam_dict=flags_with_blank_optimizer)
Ejemplo n.º 2
0
 def test_remove_unused_flags_without_prefix(self):
   """Expects 'adam_momentum' to be dropped for an 'sgd' optimizer, no prefix."""
   all_flags = collections.OrderedDict()
   all_flags['optimizer'] = 'sgd'
   all_flags['value'] = 0.1
   all_flags['sgd_momentum'] = 0.3
   all_flags['adam_momentum'] = 0.5

   kept_flags = utils_impl.remove_unused_flags(
       prefix=None, hparam_dict=all_flags)

   # Every surviving flag must keep the value it was given.
   expected_items = [('optimizer', 'sgd'), ('value', 0.1),
                     ('sgd_momentum', 0.3)]
   self.assertCountEqual(kept_flags.keys(),
                         [name for name, _ in expected_items])
   for name, value in expected_items:
     self.assertEqual(kept_flags[name], value)
Ejemplo n.º 3
0
  def test_remove_unused_flags_with_prefix(self):
    """Expects only 'client_adam_momentum' to be dropped for prefix 'client'."""
    all_flags = collections.OrderedDict()
    all_flags['client_optimizer'] = 'sgd'
    all_flags['non_client_value'] = 0.1
    all_flags['client_sgd_momentum'] = 0.3
    all_flags['client_adam_momentum'] = 0.5

    kept_flags = utils_impl.remove_unused_flags('client', all_flags)

    # Every surviving flag must keep the value it was given.
    expected_items = [('client_optimizer', 'sgd'), ('non_client_value', 0.1),
                      ('client_sgd_momentum', 0.3)]
    self.assertCountEqual(kept_flags.keys(),
                          [name for name, _ in expected_items])
    for name, value in expected_items:
      self.assertEqual(kept_flags[name], value)
Ejemplo n.º 4
0
def _run_experiment():
    """Load EMNIST, build the training callbacks and run federated averaging.

    All settings are read from ``FLAGS``.

    Raises:
        ValueError: If ``FLAGS.server_optimizer`` is neither 'sgd' nor
            'flars'.
    """
    train_data, test_data = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.digit_only_emnist)

    def _prepare_train(dataset):
        """Reshape, shuffle, repeat and batch a training dataset."""
        reshaped = dataset.map(reshape_emnist_element)
        shuffled = reshaped.shuffle(buffer_size=10000)
        repeated = shuffled.repeat(FLAGS.client_epochs_per_round)
        return repeated.batch(FLAGS.batch_size)

    def _prepare_test(dataset):
        """Reshape a test dataset and batch it in groups of 100."""
        reshaped = dataset.map(reshape_emnist_element)
        return reshaped.batch(100, drop_remainder=False)

    train_data = train_data.preprocess(_prepare_train)
    test_data = _prepare_test(test_data.create_tf_dataset_from_all_clients())

    # The first client's dataset supplies the element spec for the model.
    first_client_dataset = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0])
    input_spec = first_client_dataset.element_spec

    def tff_model_fn():
        """Wrap a newly built Keras model with tff.learning.from_keras_model."""
        model = model_builder()
        return tff.learning.from_keras_model(
            model,
            input_spec=input_spec,
            loss=loss_builder(),
            metrics=metrics_builder())

    def _sample_client_datasets(round_num):
        """Return the datasets of a random sample of training clients."""
        del round_num  # Unused.
        chosen_ids = random.sample(population=train_data.client_ids,
                                   k=FLAGS.train_clients_per_round)
        return [
            train_data.create_tf_dataset_for_client(client_id)
            for client_id in chosen_ids
        ]

    def _evaluate(state):
        """Evaluate ``state.model`` on the test data with a Keras model."""
        eval_model = compiled_eval_keras_model()
        tff.learning.assign_weights_to_keras_model(eval_model, state.model)
        results = eval_model.evaluate(test_data, verbose=0)
        loss_value = results[0]
        accuracy_value = results[1]
        return {
            'loss': loss_value,
            'sparse_categorical_accuracy': accuracy_value,
        }

    tf.io.gfile.makedirs(FLAGS.root_output_dir)
    all_hparams = collections.OrderedDict(
        (flag_name, FLAGS[flag_name].value) for flag_name in hparam_flags)
    used_hparams = utils_impl.remove_unused_flags('client', all_hparams)

    metrics_hook = _MetricsHook(FLAGS.exp_name, FLAGS.root_output_dir,
                                used_hparams)

    def _make_client_optimizer():
        """Build the client optimizer from the 'client'-prefixed flags."""
        return utils_impl.create_optimizer_from_flags('client')

    # Validation happens here, after the output directory and metrics hook
    # exist, so an unsupported optimizer fails at the same point as before.
    server_optimizer_name = FLAGS.server_optimizer
    if server_optimizer_name not in ('sgd', 'flars'):
        raise ValueError('Optimizer %s is not supported.' %
                         server_optimizer_name)

    if server_optimizer_name == 'sgd':
        server_optimizer_fn = functools.partial(
            tf.keras.optimizers.SGD,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum)
    else:
        server_optimizer_fn = functools.partial(
            flars_optimizer.FLARSOptimizer,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum,
            max_ratio=FLAGS.max_ratio)

    _federated_averaging_training_loop(
        model_fn=tff_model_fn,
        client_optimizer_fn=_make_client_optimizer,
        server_optimizer_fn=server_optimizer_fn,
        client_datasets_fn=_sample_client_datasets,
        evaluate_fn=_evaluate,
        total_rounds=FLAGS.total_rounds,
        rounds_per_eval=FLAGS.rounds_per_eval,
        metrics_hook=metrics_hook)
Ejemplo n.º 5
0
 def test_remove_unused_flags_without_optimizer_flag(self):
     """Expects a ValueError when no 'client_optimizer' flag is present."""
     flags_missing_optimizer = collections.OrderedDict()
     flags_missing_optimizer['client_opt_fn'] = 'sgd'
     flags_missing_optimizer['client_sgd_momentum'] = 0.3
     expected_message = 'The flag client_optimizer was not defined.'

     with self.assertRaisesRegex(ValueError, expected_message):
         utils_impl.remove_unused_flags('client', flags_missing_optimizer)
Ejemplo n.º 6
0
def run_experiment():
    """Load EMNIST, build the training callbacks and run federated averaging.

    All settings are read from ``FLAGS``.

    Raises:
        ValueError: If ``FLAGS.server_optimizer`` is neither 'sgd' nor
            'flars'.
    """
    train_data, test_data = tff.simulation.datasets.emnist.load_data(
        only_digits=FLAGS.digit_only_emnist)

    example_tuple = collections.namedtuple('Example', ['x', 'y'])

    def _to_example(element):
        """Flatten 'pixels' into x and reshape 'label' to shape [1] as y."""
        flat_pixels = tf.reshape(element['pixels'], [-1])
        label = tf.reshape(element['label'], [1])
        return example_tuple(x=flat_pixels, y=label)

    def _prepare_train(dataset):
        """Convert, shuffle-and-repeat, then batch a training dataset."""
        shuffle_repeat = tf.data.experimental.shuffle_and_repeat(
            buffer_size=10000, count=FLAGS.client_epochs_per_round)
        converted = dataset.map(_to_example)
        return converted.apply(shuffle_repeat).batch(FLAGS.batch_size)

    def _prepare_test(dataset):
        """Convert a test dataset and batch it in groups of 100."""
        converted = dataset.map(_to_example)
        return converted.batch(100, drop_remainder=False)

    train_data = train_data.preprocess(_prepare_train)
    test_data = _prepare_test(test_data.create_tf_dataset_from_all_clients())

    # One batch from the first client, converted via .numpy(), is handed to
    # from_compiled_keras_model as the sample batch.
    first_client_dataset = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0])
    first_batch = next(iter(first_client_dataset))
    sample_batch = tf.nest.map_structure(lambda tensor: tensor.numpy(),
                                         first_batch)

    def model_fn():
        """Wrap a newly compiled Keras model as a TFF learning model."""
        compiled_model = create_compiled_keras_model()
        return tff.learning.from_compiled_keras_model(compiled_model,
                                                      sample_batch)

    def _sample_client_datasets(round_num):
        """Return the datasets of a random sample of training clients."""
        del round_num  # Unused.
        chosen_ids = random.sample(population=train_data.client_ids,
                                   k=FLAGS.train_clients_per_round)
        return [
            train_data.create_tf_dataset_for_client(client_id)
            for client_id in chosen_ids
        ]

    tf.io.gfile.makedirs(FLAGS.root_output_dir)
    all_hparams = collections.OrderedDict(
        (flag_name, FLAGS[flag_name].value) for flag_name in hparam_flags)
    used_hparams = utils_impl.remove_unused_flags('client', all_hparams)

    metrics_hook = MetricsHook.build(FLAGS.exp_name, FLAGS.root_output_dir,
                                     test_data, used_hparams)

    # Validation happens here, after the output directory and metrics hook
    # exist, so an unsupported optimizer fails at the same point as before.
    server_optimizer_name = FLAGS.server_optimizer
    if server_optimizer_name not in ('sgd', 'flars'):
        raise ValueError('Optimizer %s is not supported.' %
                         server_optimizer_name)

    if server_optimizer_name == 'sgd':
        optimizer_fn = functools.partial(
            tf.keras.optimizers.SGD,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum)
    else:
        optimizer_fn = functools.partial(
            flars_optimizer.FLARSOptimizer,
            learning_rate=FLAGS.server_learning_rate,
            momentum=FLAGS.server_momentum,
            max_ratio=FLAGS.max_ratio)

    federated_averaging_training_loop(
        model_fn,
        optimizer_fn,
        _sample_client_datasets,
        total_rounds=FLAGS.total_rounds,
        rounds_per_eval=FLAGS.rounds_per_eval,
        metrics_hook=metrics_hook)