Exemplo n.º 1
0
 def model_fn():
   """Wraps a freshly built original FedAvg CNN as a TFF model.

   Uses `sample_batch` as the dummy batch, a `SparseCategoricalCrossentropy`
   loss object, and the metrics in `METRICS_LIST`.
   """
   cnn = models.create_original_fedavg_cnn_model()
   loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
   return tff.learning.from_keras_model(
       cnn,
       dummy_batch=sample_batch,
       loss=loss_object,
       metrics=METRICS_LIST)
Exemplo n.º 2
0
 def model_fn():
   """Wraps a freshly built original FedAvg CNN as a TFF model.

   Trains against sparse categorical cross-entropy (function form) and reports
   sparse categorical accuracy; `sample_batch` is used as the dummy batch.
   """
   cnn = models.create_original_fedavg_cnn_model()
   loss_callable = tf.keras.losses.sparse_categorical_crossentropy
   accuracy_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
   return tff.learning.from_keras_model(
       cnn,
       dummy_batch=sample_batch,
       loss=loss_callable,
       metrics=accuracy_metrics)
Exemplo n.º 3
0
def create_compiled_keras_model():
    """Build the original FedAvg CNN and compile it for client training.

    The model is compiled with sparse categorical cross-entropy loss, the
    optimizer produced by `utils_impl.create_optimizer_from_flags('client')`,
    and a sparse categorical accuracy metric.

    Returns:
      The compiled Keras model.
    """
    cnn = models.create_original_fedavg_cnn_model()
    client_optimizer = utils_impl.create_optimizer_from_flags('client')
    cnn.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=client_optimizer,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return cnn
Exemplo n.º 4
0
def create_compiled_keras_model():
    """Build the original FedAvg CNN and compile it for client training.

    `only_digits` is taken from `FLAGS.digit_only_emnist`. The model is
    compiled with sparse categorical cross-entropy loss, the optimizer from
    `utils_impl.create_optimizer_from_flags('client')`, and sparse
    categorical accuracy.

    Returns:
      The compiled Keras model.
    """
    digits_only = FLAGS.digit_only_emnist
    cnn = models.create_original_fedavg_cnn_model(only_digits=digits_only)
    client_optimizer = utils_impl.create_optimizer_from_flags('client')
    cnn.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=client_optimizer,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return cnn
Exemplo n.º 5
0
def create_compiled_keras_model():
  """Build the model named by `FLAGS.training_model` and compile it.

  Recognized names are 'cnn', 'orig_cnn', '2nn' and 'resnet' (9 blocks); each
  model is built with `only_digits=FLAGS.only_digits`. Compilation uses sparse
  categorical cross-entropy, SGD configured from `FLAGS.learning_rate` and
  `FLAGS.momentum`, and sparse categorical accuracy.

  Returns:
    The compiled Keras model.

  Raises:
    ValueError: If `FLAGS.training_model` is not one of the recognized names.
  """
  # Builders are lambdas so only the selected constructor is looked up/called.
  builders = {
      'cnn': lambda: models.create_conv_dropout_model(
          only_digits=FLAGS.only_digits),
      'orig_cnn': lambda: models.create_original_fedavg_cnn_model(
          only_digits=FLAGS.only_digits),
      '2nn': lambda: models.create_two_hidden_layer_model(
          only_digits=FLAGS.only_digits),
      'resnet': lambda: models.create_resnet(
          num_blocks=9, only_digits=FLAGS.only_digits),
  }
  requested = FLAGS.training_model
  if requested not in builders:
    raise ValueError('Model {} is not supported.'.format(requested))
  model = builders[requested]()

  sgd = tf.keras.optimizers.SGD(
      learning_rate=FLAGS.learning_rate, momentum=FLAGS.momentum)
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=sgd,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model
Exemplo n.º 6
0
def run_experiment():
  """Data preprocessing and experiment execution.

  Loads EMNIST, optionally expands both splits into pseudo-clients, builds the
  per-client training pipeline and a centralized test dataset, then runs
  federated averaging with a differentially private aggregation built from
  `tff.utils.build_dp_query` and uniform client weighting.
  """
  emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
  if FLAGS.num_pseudo_clients > 1:
    emnist_train = tff.simulation.datasets.emnist.get_infinite(
        emnist_train, FLAGS.num_pseudo_clients)
    emnist_test = tff.simulation.datasets.emnist.get_infinite(
        emnist_test, FLAGS.num_pseudo_clients)

  example_tuple = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    # Flatten the pixels to 1-D and reshape the label to shape [1].
    return example_tuple(
        x=tf.reshape(element['pixels'], [-1]),
        y=tf.reshape(element['label'], [1]))

  def preprocess_train_dataset(dataset):
    """Preprocess training dataset."""
    # Shuffle, repeat once per local epoch, then batch.
    return (dataset.map(element_fn).shuffle(buffer_size=10000).repeat(
        FLAGS.client_epochs_per_round).batch(FLAGS.batch_size))

  def preprocess_test_dataset(dataset):
    """Preprocess testing dataset."""
    return dataset.map(element_fn).batch(100, drop_remainder=False)

  emnist_train = emnist_train.preprocess(preprocess_train_dataset)
  # The test set is pooled across all clients for centralized evaluation.
  emnist_test = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())

  # One concrete batch, used as the dummy batch when building TFF models.
  example_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[0])
  sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                       next(iter(example_dataset)))

  def client_datasets_fn(round_num):
    """Returns a list of client datasets."""
    del round_num  # Unused.
    # Uniform sampling without replacement, independent of the round.
    sampled_clients = np.random.choice(
        emnist_train.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    return [
        emnist_train.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

  tf.io.gfile.makedirs(FLAGS.root_output_dir)
  hparam_dict = collections.OrderedDict([
      (name, FLAGS[name].value) for name in hparam_flags
  ])

  # Compiled without an optimizer; this copy is only handed to the metrics
  # hook together with the test dataset.
  keras_model = models.create_original_fedavg_cnn_model()
  keras_model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  the_metrics_hook = metrics_hook.MetricsHook.build(FLAGS.exp_name,
                                                    FLAGS.root_output_dir,
                                                    emnist_test, hparam_dict,
                                                    keras_model)

  client_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'client')
  server_optimizer_fn = functools.partial(
      utils_impl.create_optimizer_from_flags, 'server')

  def model_fn():
    keras_model = models.create_original_fedavg_cnn_model()
    return tff.learning.from_keras_model(
        keras_model,
        dummy_batch=sample_batch,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  # A model instance is passed to the DP query builder (last argument).
  model = model_fn()
  dp_query = tff.utils.build_dp_query(
      FLAGS.clip, FLAGS.noise_multiplier, FLAGS.train_clients_per_round,
      FLAGS.adaptive_clip_learning_rate, FLAGS.target_unclipped_quantile,
      FLAGS.clipped_count_budget_allocation, FLAGS.train_clients_per_round,
      FLAGS.use_per_vector, model)

  # Uniform weighting.
  def client_weight_fn(outputs):
    del outputs  # unused.
    return 1.0

  dp_aggregate_fn, _ = tff.utils.build_dp_aggregate(dp_query)

  training_loops.federated_averaging_training_loop(
      model_fn,
      client_optimizer_fn=client_optimizer_fn,
      # Keyword mirrors `client_optimizer_fn`; assumes the training loop names
      # this parameter `server_optimizer_fn` (the former `server_fn=` did not
      # match) -- confirm against training_loops' signature.
      server_optimizer_fn=server_optimizer_fn,
      client_datasets_fn=client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=the_metrics_hook,
      client_weight_fn=client_weight_fn,
      stateful_delta_aggregate_fn=dp_aggregate_fn)