Ejemplo n.º 1
0
def create_compiled_keras_model():
    """Build the Keras model and compile it with the client optimizer.

    The model is obtained from `models.create_keras_model()` and compiled
    with sparse categorical cross-entropy loss, the optimizer returned by
    `utils.get_optimizer_from_flags('client')`, and a sparse categorical
    accuracy metric.

    Returns:
        The compiled Keras model.
    """
    keras_model = models.create_keras_model()

    # Built in the same order the compile() arguments were evaluated before.
    client_optimizer = utils.get_optimizer_from_flags('client')
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    keras_model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=client_optimizer,
        metrics=[accuracy_metric])

    return keras_model
Ejemplo n.º 2
0
def run_experiment():
  """Prepare the EMNIST data and run the federated averaging training loop.

  Steps, in order:
    1. Seed NumPy and TensorFlow from `FLAGS.random_seed`.
    2. Load the EMNIST train/test client data through TFF.
    3. Preprocess the per-client training data (map, shuffle-and-repeat,
       batch) and build one batched test dataset from all test clients.
    4. Take one batch from the first training client as the sample batch
       used to wrap the Keras model for TFF.
    5. Create the output directory, collect the hyperparameter flag values,
       build the metrics hook and start the training loop.

  All configuration is read from `FLAGS`; nothing is returned.
  """
  seed = FLAGS.random_seed
  np.random.seed(seed)
  tf.random.set_random_seed(seed)

  raw_train, raw_test = tff.simulation.datasets.emnist.load_data()

  example_cls = collections.namedtuple('Example', ['x', 'y'])

  def element_fn(element):
    """Convert one raw element into an `Example` of flat pixels and label."""
    flat_pixels = tf.reshape(element['pixels'], [-1])
    label = tf.reshape(element['label'], [1])
    return example_cls(x=flat_pixels, y=label)

  def preprocess_train_dataset(dataset):
    """Map, shuffle-and-repeat, then batch a client training dataset."""
    examples = dataset.map(element_fn)
    shuffle_repeat = tf.data.experimental.shuffle_and_repeat(
        buffer_size=10000,
        count=FLAGS.client_epochs_per_round)
    repeated = examples.apply(shuffle_repeat)
    return repeated.batch(FLAGS.batch_size)

  def preprocess_test_dataset(dataset):
    """Map and batch the test dataset, keeping the final partial batch."""
    examples = dataset.map(element_fn)
    return examples.batch(100, drop_remainder=False)

  train_clients = raw_train.preprocess(preprocess_train_dataset)
  all_test_examples = raw_test.create_tf_dataset_from_all_clients()
  test_dataset = preprocess_test_dataset(all_test_examples)

  # One concrete batch (as NumPy values) from the first training client.
  first_client_id = train_clients.client_ids[0]
  first_client_dataset = train_clients.create_tf_dataset_for_client(
      first_client_id)
  first_batch = next(iter(first_client_dataset))
  sample_batch = tf.nest.map_structure(
      lambda tensor: tensor.numpy(), first_batch)

  def model_fn():
    """Wrap a freshly compiled Keras model as a TFF learning model."""
    return tff.learning.from_compiled_keras_model(
        create_compiled_keras_model(), sample_batch)

  def client_datasets_fn(round_num):
    """Returns datasets for a fresh random sample of training clients."""
    del round_num  # Sampling does not depend on the round index.
    chosen_ids = np.random.choice(
        train_clients.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    datasets = []
    for client_id in chosen_ids:
      datasets.append(train_clients.create_tf_dataset_for_client(client_id))
    return datasets

  tf.io.gfile.makedirs(FLAGS.root_output_dir)

  # Flag name -> current flag value, in the order given by `hparam_flags`.
  hparam_dict = collections.OrderedDict()
  for flag_name in hparam_flags:
    hparam_dict[flag_name] = FLAGS[flag_name].value

  metrics_hook = MetricsHook.build(
      FLAGS.exp_name, FLAGS.root_output_dir, test_dataset, hparam_dict)

  def optimizer_fn():
    """Return the optimizer configured by the 'server' flags."""
    return utils.get_optimizer_from_flags('server')

  training_loops.federated_averaging_training_loop(
      model_fn,
      optimizer_fn,
      client_datasets_fn,
      total_rounds=FLAGS.total_rounds,
      rounds_per_eval=FLAGS.rounds_per_eval,
      metrics_hook=metrics_hook)