def create_compiled_keras_model():
    """Build the original FedAvg CNN and compile it for client training.

    Returns:
      A `tf.keras.Model` compiled with sparse categorical cross-entropy,
      the client optimizer constructed from command-line flags, and a
      sparse categorical accuracy metric.
    """
    cnn = models.create_original_fedavg_cnn_model()

    # The client optimizer (type and hyperparameters) comes from flags.
    client_optimizer = utils_impl.get_optimizer_from_flags('client')
    cnn.compile(
        optimizer=client_optimizer,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return cnn
# Example #2
# 0
 def test_get_server_optimizer_from_flags(self, optimizer_name, optimizer_cls):
   """Server optimizer flags select the class and honor value overrides."""
   FLAGS['{}_optimizer'.format(TEST_SERVER_FLAG_PREFIX)].value = optimizer_name

   # With only the optimizer name set, defaults should apply.
   built = utils_impl.get_optimizer_from_flags(TEST_SERVER_FLAG_PREFIX)
   self.assertIsInstance(built, optimizer_cls)

   # A non-default learning-rate flag value must be picked up.
   lr_flag = '{}_{}'.format(TEST_SERVER_FLAG_PREFIX, 'learning_rate')
   FLAGS[lr_flag].value = 100.0
   built = utils_impl.get_optimizer_from_flags(TEST_SERVER_FLAG_PREFIX)
   self.assertIsInstance(built, optimizer_cls)
   self.assertEqual(built.get_config()['learning_rate'], 100.0)

   # An explicit overrides dict takes precedence over the flag value.
   built = utils_impl.get_optimizer_from_flags(TEST_SERVER_FLAG_PREFIX,
                                               {'learning_rate': 5.0})
   self.assertIsInstance(built, optimizer_cls)
   self.assertEqual(built.get_config()['learning_rate'], 5.0)
# Example #3
# 0
 def test_get_optimizer_from_flags_flags_set_not_for_optimizer(self):
   """Setting a flag that belongs to a different optimizer raises ValueError."""
   FLAGS['{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)].value = 'sgd'
   # Set an Adam flag that isn't used in SGD.
   # We need to use `_parse_args` because that is the only way FLAGS is
   # notified that a non-default value is being used.
   FLAGS._parse_args(
       args=['--{}_adam_beta_1=0.5'.format(TEST_CLIENT_FLAG_PREFIX)],
       known_only=True)
   # The error message is expected to name both the selected optimizer
   # ([sgd]) and the offending flag.
   with self.assertRaisesRegex(
       ValueError,
       r'Commandline flags for .*\[sgd\].*\'test_client_adam_beta_1\'.*'):
     _ = utils_impl.get_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)
def run_experiment():
    """Data preprocessing and experiment execution.

    Seeds the numpy and TF RNGs, loads EMNIST (optionally expanded into
    pseudo-clients), builds preprocessed train/test datasets, constructs a
    differentially-private aggregation from flags, and runs the federated
    averaging training loop.
    """
    # Fix both RNGs so client sampling and model init are reproducible.
    np.random.seed(FLAGS.random_seed)
    tf.random.set_random_seed(FLAGS.random_seed)

    emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
    # Optionally enlarge the client population by splitting each real
    # client into `num_pseudo_clients` pseudo-clients.
    if FLAGS.num_pseudo_clients > 1:
        emnist_train = tff.simulation.datasets.emnist.get_infinite(
            emnist_train, FLAGS.num_pseudo_clients)
        emnist_test = tff.simulation.datasets.emnist.get_infinite(
            emnist_test, FLAGS.num_pseudo_clients)

    example_tuple = collections.namedtuple('Example', ['x', 'y'])

    def element_fn(element):
        # Flatten the pixel grid to a 1-D feature vector; label -> shape [1].
        return example_tuple(x=tf.reshape(element['pixels'], [-1]),
                             y=tf.reshape(element['label'], [1]))

    def preprocess_train_dataset(dataset):
        """Preprocess training dataset: map, shuffle/repeat per-round epochs, batch."""
        return dataset.map(element_fn).apply(
            tf.data.experimental.shuffle_and_repeat(
                buffer_size=10000,
                count=FLAGS.client_epochs_per_round)).batch(FLAGS.batch_size)

    def preprocess_test_dataset(dataset):
        """Preprocess testing dataset: map and batch (no shuffle/repeat)."""
        return dataset.map(element_fn).batch(100, drop_remainder=False)

    emnist_train = emnist_train.preprocess(preprocess_train_dataset)
    emnist_test = preprocess_test_dataset(
        emnist_test.create_tf_dataset_from_all_clients())

    # One concrete batch (as numpy) is needed to build the TFF model below.
    example_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[0])
    sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                         next(iter(example_dataset)))

    def client_datasets_fn(round_num):
        """Returns a list of client datasets."""
        del round_num  # Unused.
        # Sample clients uniformly without replacement each round.
        sampled_clients = np.random.choice(emnist_train.client_ids,
                                           size=FLAGS.train_clients_per_round,
                                           replace=False)
        return [
            emnist_train.create_tf_dataset_for_client(client)
            for client in sampled_clients
        ]

    tf.io.gfile.makedirs(FLAGS.root_output_dir)
    # Snapshot hyperparameter flag values for logging via the metrics hook.
    hparam_dict = collections.OrderedDict([(name, FLAGS[name].value)
                                           for name in hparam_flags])

    keras_model = create_compiled_keras_model()
    the_metrics_hook = metrics_hook.MetricsHook.build(FLAGS.exp_name,
                                                      FLAGS.root_output_dir,
                                                      emnist_test, hparam_dict,
                                                      keras_model)

    optimizer_fn = lambda: utils_impl.get_optimizer_from_flags('server')

    model = tff.learning.from_compiled_keras_model(keras_model, sample_batch)
    # Build the DP query from clipping/noise flags.
    # NOTE(review): FLAGS.train_clients_per_round is passed twice — presumably
    # once as the expected records per round and once for the clipped-count
    # estimation; confirm against tff.utils.build_dp_query's signature.
    dp_query = tff.utils.build_dp_query(
        FLAGS.clip, FLAGS.noise_multiplier, FLAGS.train_clients_per_round,
        FLAGS.adaptive_clip_learning_rate, FLAGS.target_unclipped_quantile,
        FLAGS.clipped_count_budget_allocation, FLAGS.train_clients_per_round,
        FLAGS.use_per_vector, model)

    # Uniform weighting.
    def client_weight_fn(outputs):
        del outputs  # unused.
        return 1.0

    dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)

    def model_fn():
        # A fresh compiled model per invocation, as TFF requires.
        keras_model = create_compiled_keras_model()
        return tff.learning.from_compiled_keras_model(keras_model,
                                                      sample_batch)

    training_loops.federated_averaging_training_loop(
        model_fn,
        optimizer_fn,
        client_datasets_fn,
        total_rounds=FLAGS.total_rounds,
        rounds_per_eval=FLAGS.rounds_per_eval,
        metrics_hook=the_metrics_hook,
        client_weight_fn=client_weight_fn,
        stateful_delta_aggregate_fn=dp_aggregate_fn)
# Example #5
# 0
def run_experiment():
  """Data preprocessing and experiment execution.

  Seeds the numpy and TF RNGs, loads and preprocesses EMNIST, then runs
  the federated averaging training loop with a flag-configured server
  optimizer and a metrics hook for evaluation output.
  """
  # Fix both RNGs so client sampling and model init are reproducible.
  np.random.seed(FLAGS.random_seed)
  tf.random.set_random_seed(FLAGS.random_seed)

  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    # Flatten the pixel grid to a 1-D feature vector; label -> shape [1].
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset: map, shuffle/repeat per-round epochs, batch."""
    return dataset.map(element_fn).apply(
        tf.data.experimental.shuffle_and_repeat(
            buffer_size=10000,
            count=FLAGS.client_epochs_per_round)).batch(FLAGS.batch_size)

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset: map and batch (no shuffle/repeat)."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # One concrete batch (as numpy) is needed to build the TFF model below.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def model_fn():
    # A fresh compiled model per invocation, as TFF requires.
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Sample clients uniformly without replacement each round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  # Snapshot hyperparameter flag values for logging via the metrics hook.
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  metrics_hook = MetricsHook.build(FLAGS.exp_name, FLAGS.root_output_dir,
                                   emnist_test, hparam_dict)

  optimizer_fn = lambda: utils_impl.get_optimizer_from_flags('server')

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)
# Example #6
# 0
 def test_get_optimizer_from_flags(self):
     """A default supplied at flag-definition time reaches the optimizer."""
     utils_impl.define_optimizer_flags(
         'server', defaults=dict(learning_rate=1.25))
     # The default should be visible on the flag itself...
     self.assertEqual(FLAGS.server_learning_rate, 1.25)
     # ...and carried through into the constructed optimizer's config.
     built = utils_impl.get_optimizer_from_flags('server')
     self.assertEqual(built.get_config()['learning_rate'], 1.25)
# Example #7
# 0
 def test_get_optimizer_from_flags_invalid_optimizer(self):
     """An unknown optimizer name should raise a descriptive ValueError."""
     flag_name = '{}_optimizer'.format(TEST_CLIENT_FLAG_PREFIX)
     FLAGS[flag_name].value = 'foo'
     with self.assertRaisesRegex(ValueError, 'not a valid optimizer'):
         utils_impl.get_optimizer_from_flags(TEST_CLIENT_FLAG_PREFIX)