def main(argv):
  """Runs federated averaging on EMNIST with periodic centralized evaluation.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = create_original_fedavg_cnn_model(only_digits=True)
    # The CNN emits unnormalized logits, so the loss must be computed from
    # logits; this also matches the multi-GPU variant of main() in this
    # project, which passes from_logits=True for the same model.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return simple_fedavg_tf.KerasModelWrapper(keras_model,
                                              test_data.element_spec, loss)

  iterative_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn)
  server_state = iterative_process.initialize()
  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  model = tff_model_fn()

  for round_num in range(FLAGS.total_rounds):
    # Sample a fresh, non-repeating subset of clients each round.
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    server_state, train_metrics = iterative_process.next(
        server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics}')
    if round_num % FLAGS.rounds_per_eval == 0:
      # Copy the current server weights into a local Keras model for eval.
      model.from_weights(server_state.model_weights)
      accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data,
                                                 metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
def test_something(self, model_fn):
  """Checks the client-data type signature of the built fedavg process."""
  fedavg_process = simple_fedavg_tff.build_federated_averaging_process(
      model_fn)
  self.assertIsInstance(fedavg_process, tff.templates.IterativeProcess)
  # The second `next` parameter is the federated client-dataset placement.
  client_data_type = fedavg_process.next.type_signature.parameter[1]
  self.assertEqual(
      str(client_data_type), '{<x=float32[?,28,28,1],y=int32[?]>*}@CLIENTS')
def test_build_fedavg_process(self):
  """Verifies server-state and client-data types for the RNN model."""
  process = simple_fedavg_tff.build_federated_averaging_process(_rnn_model_fn)
  self.assertIsInstance(process, tff.templates.IterativeProcess)
  param_types = process.next.type_signature.parameter
  weights_type = tff.learning.framework.weights_type_from_model(_rnn_model_fn)
  expected_server_type = (
      '<model_weights={},optimizer_state=<int64>,round_num=int32>@SERVER'
      .format(weights_type))
  self.assertEqual(str(param_types[0]), expected_server_type)
  self.assertEqual(
      str(param_types[1]), '{<x=int32[?,5],y=int32[?,5]>*}@CLIENTS')
def test_self_contained_example_custom_model(self):
  """Runs two training rounds on MnistModel; the loss should decrease."""
  make_client_dataset = _create_client_data()
  federated_train_data = [make_client_dataset()]
  trainer = simple_fedavg_tff.build_federated_averaging_process(MnistModel)
  state = trainer.initialize()
  round_losses = []
  for _ in range(2):
    state, round_loss = trainer.next(state, federated_train_data)
    round_losses.append(round_loss)
  self.assertLess(round_losses[1], round_losses[0])
def main(argv):
  """Entry point: federated EMNIST training with multi-GPU client execution."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Left to its defaults, TFF would place work on the first GPU like TF does.
  # Instead, spread client computation across every detected GPU and pin the
  # server computation to the CPU, avoiding potential out-of-memory failures
  # when many clients are sampled per round. `gpu_devices` may be an empty
  # list when TF detects no GPU.
  gpu_devices = tf.config.list_logical_devices('GPU')
  cpu_device = tf.config.list_logical_devices('CPU')[0]
  tff.backends.native.set_local_execution_context(
      server_tf_device=cpu_device, client_tf_devices=gpu_devices)

  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = create_original_fedavg_cnn_model(only_digits=True)
    # The model outputs logits, hence from_logits=True.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return simple_fedavg_tf.KerasModelWrapper(keras_model,
                                              test_data.element_spec, loss)

  fedavg_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn)
  server_state = fedavg_process.initialize()
  eval_metric = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
  eval_model = tff_model_fn()

  for round_num in range(FLAGS.total_rounds):
    # Draw a fresh, non-repeating client sample for this round.
    client_ids = np.random.choice(
        train_data.client_ids,
        size=FLAGS.train_clients_per_round,
        replace=False)
    round_datasets = [
        train_data.create_tf_dataset_for_client(cid) for cid in client_ids
    ]
    server_state, train_metrics = fedavg_process.next(server_state,
                                                      round_datasets)
    print(f'Round {round_num} training loss: {train_metrics}')
    if round_num % FLAGS.rounds_per_eval == 0:
      # Pull current server weights into a local Keras model for evaluation.
      eval_model.from_weights(server_state.model_weights)
      accuracy = simple_fedavg_tf.keras_evaluate(eval_model.keras_model,
                                                 test_data, eval_metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
def test_simple_training(self, model_fn):
  """Trains on one deterministic batch; later rounds beat the first loss."""
  process = simple_fedavg_tff.build_federated_averaging_process(model_fn)
  state = process.initialize()

  def deterministic_batch():
    # Fixed all-ones batch keeps the test reproducible.
    return collections.OrderedDict(
        x=np.ones([1, 28, 28, 1], dtype=np.float32),
        y=np.ones([1], dtype=np.int32))

  batch = tff.tf_computation(deterministic_batch)()
  federated_data = [[batch]]

  losses = []
  for _ in range(3):
    state, round_loss = process.next(state, federated_data)
    losses.append(round_loss)
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_client_adagrad_train(self):
  """Trains the RNN model with an Adagrad client optimizer; loss decreases."""
  process = simple_fedavg_tff.build_federated_averaging_process(
      _rnn_model_fn,
      client_optimizer_fn=functools.partial(
          tf.keras.optimizers.Adagrad, learning_rate=0.01))
  state = process.initialize()

  def deterministic_batch():
    # Fixed sequence batch keeps the test reproducible.
    return collections.OrderedDict(
        x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
        y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

  batch = tff.tf_computation(deterministic_batch)()
  federated_data = [[batch]]

  losses = []
  for _ in range(3):
    state, round_loss = process.next(state, federated_data)
    losses.append(round_loss)
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_tff_learning_evaluate(self):
  """Evaluates server-state weights, loaded into a Keras model, on one batch.

  The original body defined `sample_data` twice with identical contents;
  the first definition was dead code (overwritten before use) and has been
  removed.
  """
  it_process = simple_fedavg_tff.build_federated_averaging_process(
      _tff_learning_model_fn)
  server_state = it_process.initialize()
  keras_model = _create_test_cnn_model()
  # Copy the freshly initialized server weights into the Keras model.
  server_state.model_weights.assign_weights_to(keras_model)
  # A single deterministic all-ones batch suffices to exercise evaluation.
  sample_data = [
      collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))
  ]
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  accuracy = simple_fedavg_tf.keras_evaluate(keras_model, sample_data, metric)
  self.assertIsInstance(accuracy, tf.Tensor)
  self.assertBetween(accuracy, 0.0, 1.0)