def main(argv):
  """Trains a federated EMNIST model and periodically evaluates it.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If any extra command-line argument is passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = create_original_fedavg_cnn_model(only_digits=True)
    # The CNN built by create_original_fedavg_cnn_model is used elsewhere in
    # this file with from_logits=True; omitting it here made the loss treat
    # raw logits as probabilities and silently optimize the wrong objective.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return simple_fedavg_tf.KerasModelWrapper(keras_model,
                                              test_data.element_spec, loss)

  iterative_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn)
  server_state = iterative_process.initialize()
  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  model = tff_model_fn()
  for round_num in range(FLAGS.total_rounds):
    # Uniformly sample a fresh cohort of clients for this round, without
    # replacement within the round.
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    server_state, train_metrics = iterative_process.next(
        server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics}')
    if round_num % FLAGS.rounds_per_eval == 0:
      # Copy the current server weights into a local Keras model so the
      # centralized test set can be evaluated outside the TFF computation.
      model.from_weights(server_state.model_weights)
      accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data,
                                                 metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
def test_keras_evaluate(self):
  """keras_evaluate on one all-ones example yields an accuracy in [0, 1]."""
  model = _create_test_cnn_model()
  batch = collections.OrderedDict(
      x=np.ones([1, 28, 28, 1], dtype=np.float32),
      y=np.ones([1], dtype=np.int32))
  acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
  result = simple_fedavg_tf.keras_evaluate(model, [batch], acc_metric)
  # The helper should report accuracy as a tensor bounded by [0, 1].
  self.assertIsInstance(result, tf.Tensor)
  self.assertBetween(result, 0.0, 1.0)
def main(argv):
  """Trains a federated EMNIST model (multi-GPU aware) and evaluates it.

  Args:
    argv: Command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If any extra command-line argument is passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # If GPU is provided, TFF will by default use the first GPU like TF. The
  # following lines will configure TFF to use multi-GPUs and distribute client
  # computation on the GPUs. Note that we put server computation on CPU to
  # avoid potential out of memory issue when a large number of clients is
  # sampled per round. The client devices below can be an empty list when no
  # GPU could be detected by TF.
  client_devices = tf.config.list_logical_devices('GPU')
  server_device = tf.config.list_logical_devices('CPU')[0]
  tff.backends.native.set_local_execution_context(
      server_tf_device=server_device, client_tf_devices=client_devices)
  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = create_original_fedavg_cnn_model(only_digits=True)
    # from_logits=True: the model outputs raw logits, so the loss applies
    # the softmax internally.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return simple_fedavg_tf.KerasModelWrapper(keras_model,
                                              test_data.element_spec, loss)

  iterative_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn)
  server_state = iterative_process.initialize()
  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  model = tff_model_fn()
  for round_num in range(FLAGS.total_rounds):
    # Uniformly sample this round's client cohort without replacement.
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    server_state, train_metrics = iterative_process.next(
        server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics}')
    if round_num % FLAGS.rounds_per_eval == 0:
      # Copy the current server weights into a local Keras model so the
      # centralized test set can be evaluated outside the TFF computation.
      model.from_weights(server_state.model_weights)
      accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data,
                                                 metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
def test_tff_learning_evaluate(self):
  """Server weights from a tff.learning model load into Keras and evaluate.

  Builds the federated averaging process from a `tff.learning` model,
  initializes it, pushes the resulting server weights into a plain Keras
  model, and checks that evaluation on one example produces a valid
  accuracy tensor.
  """
  it_process = simple_fedavg_tff.build_federated_averaging_process(
      _tff_learning_model_fn)
  server_state = it_process.initialize()
  keras_model = _create_test_cnn_model()
  # Assign the freshly initialized server weights to the Keras model under
  # test before evaluating.
  server_state.model_weights.assign_weights_to(keras_model)
  # NOTE(review): the original built this identical sample_data list twice;
  # the first construction was dead code and has been removed.
  sample_data = [
      collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))
  ]
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  accuracy = simple_fedavg_tf.keras_evaluate(keras_model, sample_data, metric)
  self.assertIsInstance(accuracy, tf.Tensor)
  self.assertBetween(accuracy, 0.0, 1.0)