Example #1
0
def main(argv):
    """Trains a model with simple federated averaging and reports progress.

    Each round samples `FLAGS.train_clients_per_round` distinct client ids
    from the training split and runs one step of the iterative process on
    their datasets. It then prints the training metrics. Every
    `FLAGS.rounds_per_eval` rounds, the current server weights are copied
    into a local Keras model and test accuracy is printed as a percentage.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    train_data, test_data = get_emnist_dataset()

    def tff_model_fn():
        """Constructs a fully initialized model for use in federated averaging."""
        cnn = create_original_fedavg_cnn_model(only_digits=True)
        # NOTE(review): Keras defaults to from_logits=False, which assumes the
        # CNN's final layer already applies a softmax. Confirm this against
        # create_original_fedavg_cnn_model.
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
        return simple_fedavg_tf.KerasModelWrapper(
            cnn, test_data.element_spec, loss_fn)

    fedavg_process = simple_fedavg_tff.build_federated_averaging_process(
        tff_model_fn, server_optimizer_fn, client_optimizer_fn)
    server_state = fedavg_process.initialize()

    eval_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
    eval_model = tff_model_fn()
    for round_num in range(FLAGS.total_rounds):
        chosen_ids = np.random.choice(
            train_data.client_ids,
            size=FLAGS.train_clients_per_round,
            replace=False)
        client_datasets = list(
            map(train_data.create_tf_dataset_for_client, chosen_ids))
        server_state, train_metrics = fedavg_process.next(
            server_state, client_datasets)
        print(f'Round {round_num} training loss: {train_metrics}')

        # Only evaluate on rounds that are a multiple of rounds_per_eval.
        if round_num % FLAGS.rounds_per_eval:
            continue
        eval_model.from_weights(server_state.model_weights)
        accuracy = simple_fedavg_tf.keras_evaluate(
            eval_model.keras_model, test_data, eval_metric)
        print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
Example #2
0
 def test_something(self, model_fn):
     """Checks the type of the client-data argument of the FedAvg `next`.

     The process built from `model_fn` must be an IterativeProcess whose
     `next` computation takes, as its second parameter, a CLIENTS-placed
     sequence of {x: float32[?,28,28,1], y: int32[?]} batches.
     """
     process = simple_fedavg_tff.build_federated_averaging_process(model_fn)
     self.assertIsInstance(process, tff.templates.IterativeProcess)
     next_params = process.next.type_signature.parameter
     client_data_type = next_params[1]
     expected_type_str = '{<x=float32[?,28,28,1],y=int32[?]>*}@CLIENTS'
     self.assertEqual(str(client_data_type), expected_type_str)
Example #3
0
 def test_build_fedavg_process(self):
   """Checks the server-state and client-data types of the RNN FedAvg process.

   The first parameter of `next` must be the SERVER-placed state (model
   weights, an int64 optimizer state and an int32 round number). The second
   must be CLIENTS-placed sequences of int32[?,5] x/y batches.
   """
   process = simple_fedavg_tff.build_federated_averaging_process(_rnn_model_fn)
   self.assertIsInstance(process, tff.templates.IterativeProcess)
   next_params = process.next.type_signature.parameter
   weights_type = tff.learning.framework.weights_type_from_model(_rnn_model_fn)

   expected_server_type = (
       '<model_weights={},optimizer_state=<int64>,round_num=int32>@SERVER'
       .format(weights_type))
   self.assertEqual(str(next_params[0]), expected_server_type)

   expected_client_type = '{<x=int32[?,5],y=int32[?,5]>*}@CLIENTS'
   self.assertEqual(str(next_params[1]), expected_client_type)
Example #4
0
  def test_self_contained_example_custom_model(self):
    """Trains `MnistModel` for two rounds on one client; loss must decrease."""
    make_client_dataset = _create_client_data()
    federated_train_data = [make_client_dataset()]

    process = simple_fedavg_tff.build_federated_averaging_process(MnistModel)
    state = process.initialize()
    round_losses = []
    for _ in range(2):
      state, loss = process.next(state, federated_train_data)
      round_losses.append(loss)

    first_loss, second_loss = round_losses
    self.assertLess(second_loss, first_loss)
Example #5
0
def main(argv):
  """Trains a model with simple federated averaging and reports progress.

  Configures TFF's local execution context so that client computation is
  spread across the visible GPUs and server computation runs on the CPU.
  Each round then samples `FLAGS.train_clients_per_round` distinct clients,
  runs one step of the iterative process and prints the training metrics.
  Every `FLAGS.rounds_per_eval` rounds it also prints test accuracy (as a
  percentage) for the current server weights.

  Args:
    argv: Command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Like TF, TFF uses only the first GPU by default. Here, client computation
  # is distributed across all detected GPUs, and server computation is pinned
  # to the CPU. This avoids a potential out-of-memory issue when many clients
  # are sampled per round. If TF detects no GPU, `gpu_devices` is an empty
  # list.
  gpu_devices = tf.config.list_logical_devices('GPU')
  cpu_device = tf.config.list_logical_devices('CPU')[0]
  tff.backends.native.set_local_execution_context(
      server_tf_device=cpu_device, client_tf_devices=gpu_devices)

  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    cnn = create_original_fedavg_cnn_model(only_digits=True)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return simple_fedavg_tf.KerasModelWrapper(
        cnn, test_data.element_spec, loss_fn)

  fedavg_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn)
  server_state = fedavg_process.initialize()

  eval_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = tff_model_fn()
  for round_num in range(FLAGS.total_rounds):
    chosen_ids = np.random.choice(
        train_data.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    client_datasets = list(
        map(train_data.create_tf_dataset_for_client, chosen_ids))
    server_state, train_metrics = fedavg_process.next(
        server_state, client_datasets)
    print(f'Round {round_num} training loss: {train_metrics}')

    # Only evaluate on rounds that are a multiple of rounds_per_eval.
    if round_num % FLAGS.rounds_per_eval:
      continue
    eval_model.from_weights(server_state.model_weights)
    accuracy = simple_fedavg_tf.keras_evaluate(
        eval_model.keras_model, test_data, eval_metric)
    print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
Example #6
0
  def test_simple_training(self, model_fn):
    """Trains three rounds on one fixed batch; later losses should be lower.

    The single client holds one all-ones batch (x: float32[1,28,28,1],
    y: int32[1]). The mean loss of rounds 2-3 must be below the loss of
    round 1.
    """
    process = simple_fedavg_tff.build_federated_averaging_process(model_fn)
    state = process.initialize()

    def deterministic_batch():
      return collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))

    make_batch = tff.tf_computation(deterministic_batch)
    federated_data = [[make_batch()]]

    round_losses = []
    for _ in range(3):
      state, loss = process.next(state, federated_data)
      round_losses.append(loss)

    first_loss, *later_losses = round_losses
    self.assertLess(np.mean(later_losses), first_loss)
Example #7
0
  def test_client_adagrad_train(self):
    """Trains the RNN with Adagrad on clients; later losses should be lower.

    Uses Adagrad(learning_rate=0.01) as the client optimizer and a single
    client holding one fixed int32 batch. The mean loss of rounds 2-3 must
    be below the loss of round 1.
    """
    adagrad_fn = functools.partial(
        tf.keras.optimizers.Adagrad, learning_rate=0.01)
    process = simple_fedavg_tff.build_federated_averaging_process(
        _rnn_model_fn, client_optimizer_fn=adagrad_fn)
    state = process.initialize()

    def deterministic_batch():
      return collections.OrderedDict(
          x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
          y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

    make_batch = tff.tf_computation(deterministic_batch)
    federated_data = [[make_batch()]]

    round_losses = []
    for _ in range(3):
      state, loss = process.next(state, federated_data)
      round_losses.append(loss)

    first_loss, *later_losses = round_losses
    self.assertLess(np.mean(later_losses), first_loss)
Example #8
0
  def test_tff_learning_evaluate(self):
    """Evaluates the initial server weights with `keras_evaluate`.

    Copies the initial global model weights into a freshly built test CNN.
    It then checks that `simple_fedavg_tf.keras_evaluate` returns a
    `tf.Tensor` accuracy between 0.0 and 1.0 on a single all-ones sample.
    """
    it_process = simple_fedavg_tff.build_federated_averaging_process(
        _tff_learning_model_fn)
    server_state = it_process.initialize()
    keras_model = _create_test_cnn_model()
    server_state.model_weights.assign_weights_to(keras_model)

    # One deterministic example: x float32[1,28,28,1], y int32[1], all ones.
    sample_data = [
        collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))
    ]
    metric = tf.keras.metrics.SparseCategoricalAccuracy()
    accuracy = simple_fedavg_tf.keras_evaluate(keras_model, sample_data, metric)
    self.assertIsInstance(accuracy, tf.Tensor)
    self.assertBetween(accuracy, 0.0, 1.0)