Example #1
0
def main(argv):
    """Run stateful federated averaging on EMNIST and print progress.

    Each round samples clients and runs one step of the iterative process
    with those clients' persisted states. It then writes the updated states
    back into `client_states`. The server model is evaluated on the test data
    every `FLAGS.rounds_per_eval` rounds and also after the final round.

    Args:
      argv: Positional command-line arguments remaining after flag parsing;
        only the program name is accepted.

    Raises:
      app.UsageError: If extra positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    train_data, test_data = get_emnist_dataset()

    def tff_model_fn():
        """Constructs a fully initialized model for use in federated averaging."""
        keras_model = create_original_fedavg_cnn_model(only_digits=True)
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
        return stateful_fedavg_tf.KerasModelWrapper(keras_model,
                                                    test_data.element_spec,
                                                    loss)

    # Initialize client states. `client_index` is each client's position in
    # `train_data.client_ids`, used below to map updated states back to ids.
    client_states = {
        client_id: stateful_fedavg_tf.ClientState(client_index=i,
                                                  iters_count=0)
        for i, client_id in enumerate(train_data.client_ids)
    }

    def get_sample_client_state():
        # Return a sample client state to initialize TFF types.
        return stateful_fedavg_tf.ClientState(client_index=-1, iters_count=0)

    iterative_process = stateful_fedavg_tff.build_federated_averaging_process(
        tff_model_fn, get_sample_client_state, server_optimizer_fn,
        client_optimizer_fn)
    server_state = iterative_process.initialize()

    metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
    model = tff_model_fn()

    # Clients are only ever sampled from the first 3 client ids, so the same
    # clients recur across rounds. NOTE(review): confirm this restriction is
    # intentional (e.g. to show per-client state accumulating) and not a
    # debugging leftover. Sampling without replacement requires
    # FLAGS.train_clients_per_round <= 3.
    candidate_clients = train_data.client_ids[:3]
    last_round = FLAGS.total_rounds - 1

    for round_num in range(FLAGS.total_rounds):
        sampled_clients = np.random.choice(candidate_clients,
                                           size=FLAGS.train_clients_per_round,
                                           replace=False)
        sampled_train_data = [
            train_data.create_tf_dataset_for_client(client)
            for client in sampled_clients
        ]
        # Sample the client states that correspond to the sampled clients.
        sampled_client_states = [
            client_states[client] for client in sampled_clients
        ]
        server_state, train_metrics, updated_client_states = iterative_process.next(
            server_state, sampled_train_data, sampled_client_states)
        print(f'Round {round_num} training loss: {train_metrics}')
        # Save updated client states back into the global `client_states` structure.
        for client_state in updated_client_states:
            client_id = train_data.client_ids[client_state.client_index]
            client_states[client_id] = client_state
            print(f'Round {round_num} iterations on client '
                  f'{client_id}: {client_state.iters_count}')
        print(f'Round {round_num} total iterations on '
              f'sampled clients: {server_state.total_iters_count}')
        # Evaluate periodically, and always after the last round so the
        # final model's accuracy is reported.
        if round_num % FLAGS.rounds_per_eval == 0 or round_num == last_round:
            model.from_weights(server_state.model_weights)
            accuracy = stateful_fedavg_tf.keras_evaluate(
                model.keras_model, test_data, metric)
            print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
Example #2
0
 def test_keras_evaluate(self):
   """Checks that keras_evaluate yields an accuracy tensor in [0.0, 1.0]."""
   model = _create_test_cnn_model()
   # A single all-ones image with label 1.
   features = np.ones([1, 28, 28, 1], dtype=np.float32)
   labels = np.ones([1], dtype=np.int32)
   dataset = [collections.OrderedDict(x=features, y=labels)]
   acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
   result = stateful_fedavg_tf.keras_evaluate(model, dataset, acc_metric)
   self.assertIsInstance(result, tf.Tensor)
   self.assertBetween(result, 0.0, 1.0)
Example #3
0
  def test_tff_learning_evaluate(self):
    """Evaluates initial server weights through keras_evaluate.

    Builds the iterative process and initializes a server state. It copies
    the server model weights into a fresh Keras model and checks that
    `keras_evaluate` returns an accuracy tensor in [0.0, 1.0].
    """
    it_process = stateful_fedavg_tff.build_federated_averaging_process(
        _tff_learning_model_fn, _create_one_client_state)
    server_state = it_process.initialize()
    keras_model = _create_test_cnn_model()
    server_state.model_weights.assign_weights_to(keras_model)

    # A single all-ones image with label 1. The original built this list
    # twice with identical contents; the first copy was never used.
    sample_data = [
        collections.OrderedDict(
            x=np.ones([1, 28, 28, 1], dtype=np.float32),
            y=np.ones([1], dtype=np.int32))
    ]
    metric = tf.keras.metrics.SparseCategoricalAccuracy()
    accuracy = stateful_fedavg_tf.keras_evaluate(keras_model, sample_data,
                                                 metric)
    self.assertIsInstance(accuracy, tf.Tensor)
    self.assertBetween(accuracy, 0.0, 1.0)