def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
  """Trains with a DP momentum server optimizer and checks loss decreases.

  Args:
    model_fn: No-arg callable returning the model to train.
    optimzer_fn: Callable building the server optimizer; invoked with
      `learning_rate`, `momentum`, `noise_std` and `model_weight_shape`.
    total_rounds: Number of federated rounds to run.
  """

  def server_optimzier_fn(model_weights):
    model_weight_shape = tf.nest.map_structure(tf.shape, model_weights)
    return optimzer_fn(
        learning_rate=1.0,
        momentum=0.9,
        noise_std=1e-5,
        model_weight_shape=model_weight_shape)

  # Removed leftover debug `print` statements.
  it_process = dp_fedavg.build_federated_averaging_process(
      model_fn, server_optimizer_fn=server_optimzier_fn)
  server_state = it_process.initialize()

  def deterministic_batch():
    return collections.OrderedDict(
        x=np.ones([1, 28, 28, 1], dtype=np.float32),
        y=np.ones([1], dtype=np.int32))

  batch = tff.tf_computation(deterministic_batch)()
  federated_data = [[batch]]

  loss_list = []
  for i in range(total_rounds):
    server_state, loss = it_process.next(server_state, federated_data)
    loss_list.append(loss)
    self.assertEqual(i + 1, server_state.round_num)
    # Fix: guard on the key that is actually read below. The old check looked
    # up the unrelated key 'server_state_type', so this assertion either never
    # ran or keyed on the wrong state entry.
    if 'dp_tree_state' in server_state.optimizer_state:
      self.assertEqual(
          i + 1,
          tree_aggregation.get_step_idx(
              server_state.optimizer_state['dp_tree_state']))
  # Training should make progress: mean of later losses below the first.
  self.assertLess(np.mean(loss_list[1:]), loss_list[0])
def test_dpftal_training(self, total_rounds=5):
  """Trains an RNN with the DP-FTRL-M server optimizer for several rounds."""

  def build_server_optimizer(model_weights):
    weight_shapes = tf.nest.map_structure(tf.shape, model_weights)
    return optimizer_utils.DPFTRLMServerOptimizer(
        learning_rate=0.1,
        momentum=0.9,
        noise_std=1e-5,
        model_weight_shape=weight_shapes)

  it_process = dp_fedavg.build_federated_averaging_process(
      _rnn_model_fn, server_optimizer_fn=build_server_optimizer)
  state = it_process.initialize()

  def make_batch():
    return collections.OrderedDict(
        x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
        y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

  federated_data = [[tff.tf_computation(make_batch)()]]

  losses = []
  for round_idx in range(total_rounds):
    state, round_loss = it_process.next(state, federated_data)
    losses.append(round_loss)
    self.assertEqual(round_idx + 1, state.round_num)
    # The tree-aggregation step counter must advance once per round.
    self.assertEqual(
        round_idx + 1,
        tree_aggregation.get_step_idx(
            state.optimizer_state['dp_tree_state'].level_state))
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_dp_momentum_training(self, model_fn, optimzer_fn, total_rounds=3):
  """Trains with a DP momentum server optimizer and checks loss decreases.

  Args:
    model_fn: No-arg callable returning the model to train.
    optimzer_fn: Callable building the server optimizer; invoked with
      `learning_rate`, `momentum`, `noise_std` and `model_weight_specs`.
    total_rounds: Number of federated rounds to run.
  """

  def server_optimzier_fn(model_weights):
    model_weight_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
    return optimzer_fn(
        learning_rate=1.0,
        momentum=0.9,
        noise_std=1e-5,
        model_weight_specs=model_weight_specs)

  it_process = dp_fedavg.build_federated_averaging_process(
      model_fn, server_optimizer_fn=server_optimzier_fn)
  server_state = it_process.initialize()

  def deterministic_batch():
    return collections.OrderedDict(
        x=np.ones([1, 28, 28, 1], dtype=np.float32),
        y=np.ones([1], dtype=np.int32))

  batch = tff.tf_computation(deterministic_batch)()
  federated_data = [[batch]]

  loss_list = []
  for i in range(total_rounds):
    server_state, loss = it_process.next(server_state, federated_data)
    loss_list.append(loss)
    self.assertEqual(i + 1, server_state.round_num)
    # Fix: the original compared the optimizer-state *instance* to the
    # FTRLState *class* with `is`, which is always False, so this assertion
    # was dead code. `isinstance` performs the intended type check.
    if isinstance(server_state.optimizer_state, optimizer_utils.FTRLState):
      self.assertEqual(
          i + 1,
          tree_aggregation.get_step_idx(
              server_state.optimizer_state.dp_tree_state))
  # Training should make progress: mean of later losses below the first.
  self.assertLess(np.mean(loss_list[1:]), loss_list[0])
def test_something(self, model_fn):
  """Checks the client-data type signature of a freshly built process."""
  process = dp_fedavg.build_federated_averaging_process(model_fn)
  self.assertIsInstance(process, tff.templates.IterativeProcess)
  client_data_type = process.next.type_signature.parameter[1]
  expected = '{<x=float32[?,28,28,1],y=int32[?]>*}@CLIENTS'
  self.assertEqual(str(client_data_type), expected)
def test_self_contained_example_custom_model(self):
  """Two rounds of FedAvg on the custom MnistModel should reduce the loss."""
  make_client_data = _create_client_data()
  train_data = [make_client_data()]
  trainer = dp_fedavg.build_federated_averaging_process(MnistModel)
  state = trainer.initialize()
  round_losses = []
  for _ in range(2):
    state, round_loss = trainer.next(state, train_data)
    round_losses.append(round_loss)
  self.assertLess(round_losses[1], round_losses[0])
def test_dpftal_restart(self, total_rounds=3):
  """Restarting the DP tree after each round should reset its step index."""

  def build_server_optimizer(model_weights):
    weight_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model_weights)
    return optimizer_utils.DPFTRLMServerOptimizer(
        learning_rate=0.1,
        momentum=0.9,
        noise_std=1e-5,
        model_weight_specs=weight_specs,
        efficient_tree=True,
        use_nesterov=True)

  it_process = dp_fedavg.build_federated_averaging_process(
      _rnn_model_fn,
      server_optimizer_fn=build_server_optimizer,
      use_simulation_loop=True)
  state = it_process.initialize()

  model = _rnn_model_fn()
  optimizer = build_server_optimizer(model.weights.trainable)

  def restart_tree(server_state):
    # Rebuild the server state with a freshly restarted DP tree; model and
    # round counter are carried over unchanged.
    return tff.structure.update_struct(
        server_state,
        model=server_state.model,
        optimizer_state=optimizer.restart_dp_tree(
            server_state.model.trainable),
        round_num=server_state.round_num)

  def make_batch():
    return collections.OrderedDict(
        x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
        y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

  federated_data = [[tff.tf_computation(make_batch)()]]

  losses = []
  for round_idx in range(total_rounds):
    state, round_loss = it_process.next(state, federated_data)
    state = restart_tree(state)
    losses.append(round_loss)
    self.assertEqual(round_idx + 1, state.round_num)
    # After a restart the tree's step index must be back at zero.
    self.assertEqual(
        0,
        tree_aggregation.get_step_idx(state.optimizer_state.dp_tree_state))
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_simple_training(self, model_fn):
  """Runs three default-configured rounds and expects the loss to drop."""
  trainer = dp_fedavg.build_federated_averaging_process(model_fn)
  state = trainer.initialize()

  def make_batch():
    return collections.OrderedDict(
        x=np.ones([1, 28, 28, 1], dtype=np.float32),
        y=np.ones([1], dtype=np.int32))

  federated_data = [[tff.tf_computation(make_batch)()]]

  losses = []
  for _ in range(3):
    state, round_loss = trainer.next(state, federated_data)
    losses.append(round_loss)
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_client_adagrad_train(self):
  """Trains the RNN with an Adagrad client optimizer; loss should drop."""
  adagrad_fn = functools.partial(
      tf.keras.optimizers.Adagrad, learning_rate=0.01)
  trainer = dp_fedavg.build_federated_averaging_process(
      _rnn_model_fn, client_optimizer_fn=adagrad_fn)
  state = trainer.initialize()

  def make_batch():
    return collections.OrderedDict(
        x=np.array([[0, 1, 2, 3, 4]], dtype=np.int32),
        y=np.array([[1, 2, 3, 4, 0]], dtype=np.int32))

  federated_data = [[tff.tf_computation(make_batch)()]]

  losses = []
  for _ in range(3):
    state, round_loss = trainer.next(state, federated_data)
    losses.append(round_loss)
  self.assertLess(np.mean(losses[1:]), losses[0])
def test_tff_learning_evaluate(self):
  """Evaluates a Keras model restored from initialized server state."""
  it_process = dp_fedavg.build_federated_averaging_process(
      _tff_learning_model_fn)
  server_state = it_process.initialize()

  keras_model = _create_test_cnn_model()
  server_state.model.assign_weights_to(keras_model)

  # Fix: `sample_data` was built twice with identical contents (the first
  # definition was dead code); define it once.
  sample_data = [
      collections.OrderedDict(
          x=np.ones([1, 28, 28, 1], dtype=np.float32),
          y=np.ones([1], dtype=np.int32))
  ]

  metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
  metrics = dp_fedavg.keras_evaluate(keras_model, sample_data, metrics)
  accuracy = metrics[0].result()
  self.assertIsInstance(accuracy, tf.Tensor)
  # Accuracy is a fraction, so it must lie in [0, 1].
  self.assertBetween(accuracy, 0.0, 1.0)
def train_and_eval():
  """Train and evaluate StackOver NWP task."""
  logging.info('Show FLAGS for debugging:')
  for flag_name in HPARAM_FLAGS:
    logging.info('%s=%s', flag_name, FLAGS[flag_name].value)

  (train_dataset_computation, train_set, validation_set,
   test_set) = _preprocess_stackoverflow(FLAGS.vocab_size,
                                         FLAGS.num_oov_buckets,
                                         FLAGS.sequence_length,
                                         FLAGS.num_validation_examples,
                                         FLAGS.client_batch_size,
                                         FLAGS.client_epochs_per_round,
                                         FLAGS.max_elements_per_user)

  input_spec = train_dataset_computation.type_signature.result.element

  def tff_model_fn():
    keras_model = models.create_recurrent_model(
        vocab_size=FLAGS.vocab_size,
        embedding_size=FLAGS.embedding_size,
        latent_size=FLAGS.latent_size,
        num_layers=FLAGS.num_layers,
        shared_embedding=FLAGS.shared_embedding)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return dp_fedavg.KerasModelWrapper(keras_model, input_spec, loss)

  # Per-round noise scale: clip norm times the noise multiplier, averaged
  # over the number of participating clients.
  noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
      FLAGS.clients_per_round)
  server_optimizer_fn = functools.partial(
      _server_optimizer_fn,
      name=FLAGS.server_optimizer,
      learning_rate=FLAGS.server_lr,
      noise_std=noise_std)
  client_optimizer_fn = functools.partial(
      _client_optimizer_fn,
      name=FLAGS.client_optimizer,
      learning_rate=FLAGS.client_lr)

  iterative_process = dp_fedavg.build_federated_averaging_process(
      tff_model_fn,
      dp_clip_norm=FLAGS.clip_norm,
      server_optimizer_fn=server_optimizer_fn,
      client_optimizer_fn=client_optimizer_fn)
  iterative_process = tff.simulation.compose_dataset_computation_with_iterative_process(
      dataset_computation=train_dataset_computation,
      process=iterative_process)

  eval_metrics = _get_stackoverflow_metrics(FLAGS.vocab_size,
                                            FLAGS.num_oov_buckets)
  model = tff_model_fn()

  def evaluate_fn(model_weights, dataset):
    model.from_weights(model_weights)
    metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                       eval_metrics)
    return collections.OrderedDict(
        (metric.name, metric.result().numpy()) for metric in metrics)

  hparam_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in HPARAM_FLAGS])

  if FLAGS.total_epochs is None:

    def client_dataset_ids_fn(round_num: int, epoch: int):
      return _sample_client_ids(FLAGS.clients_per_round, train_set, round_num,
                                epoch)

    logging.info('Sample clients for max %d rounds', FLAGS.total_rounds)
    total_epochs = 0
  else:
    client_shuffer = training_loop.ClientIDShuffler(FLAGS.clients_per_round,
                                                    train_set)
    client_dataset_ids_fn = client_shuffer.sample_client_ids
    logging.info('Shuffle clients for max %d epochs and %d rounds',
                 FLAGS.total_epochs, FLAGS.total_rounds)
    total_epochs = FLAGS.total_epochs

  training_loop.run(
      iterative_process,
      client_dataset_ids_fn,
      validation_fn=functools.partial(evaluate_fn, dataset=validation_set),
      total_epochs=total_epochs,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      train_eval_fn=None,
      test_fn=functools.partial(evaluate_fn, dataset=test_set),
      root_output_dir=FLAGS.root_output_dir,
      hparam_dict=hparam_dict,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_train_eval=2000)
def test_build_fedavg_process(self):
  """Checks the process type and its RNN client-data signature."""
  process = dp_fedavg.build_federated_averaging_process(_rnn_model_fn)
  self.assertIsInstance(process, tff.templates.IterativeProcess)
  param_types = process.next.type_signature.parameter
  self.assertEqual(
      str(param_types[1]), '{<x=int32[?,5],y=int32[?,5]>*}@CLIENTS')
def main(argv):
  """Entry point: configures the runtime and trains DP-FedAvg on EMNIST."""
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  # Run clients on GPUs (when present) and keep the server on CPU.
  client_devices = tf.config.list_logical_devices('GPU')
  server_device = tf.config.list_logical_devices('CPU')[0]
  tff.backends.native.set_local_execution_context(
      max_fanout=2 * FLAGS.clients_per_round,
      server_tf_device=server_device,
      client_tf_devices=client_devices,
      clients_per_thread=FLAGS.clients_per_thread)

  logging.info('Show FLAGS for debugging:')
  for flag_name in HPARAM_FLAGS:
    logging.info('%s=%s', flag_name, FLAGS[flag_name].value)

  train_data, test_data = _get_emnist_dataset(
      FLAGS.only_digits,
      FLAGS.client_epochs_per_round,
      FLAGS.client_batch_size,
  )

  def tff_model_fn():
    keras_model = _create_original_fedavg_cnn_model(FLAGS.only_digits)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    return dp_fedavg.KerasModelWrapper(keras_model, test_data.element_spec,
                                       loss)

  # Per-round noise scale: clip norm times the noise multiplier, averaged
  # over the number of participating clients.
  noise_std = FLAGS.clip_norm * FLAGS.noise_multiplier / float(
      FLAGS.clients_per_round)
  server_optimizer_fn = functools.partial(
      _server_optimizer_fn,
      name=FLAGS.server_optimizer,
      learning_rate=FLAGS.server_lr,
      noise_std=noise_std)
  client_optimizer_fn = functools.partial(
      _client_optimizer_fn,
      name=FLAGS.client_optimizer,
      learning_rate=FLAGS.client_lr)

  iterative_process = dp_fedavg.build_federated_averaging_process(
      tff_model_fn,
      dp_clip_norm=FLAGS.clip_norm,
      server_optimizer_fn=server_optimizer_fn,
      client_optimizer_fn=client_optimizer_fn)

  eval_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
  model = tff_model_fn()

  def evaluate_fn(model_weights, dataset):
    model.from_weights(model_weights)
    metrics = dp_fedavg.keras_evaluate(model.keras_model, dataset,
                                       eval_metrics)
    return collections.OrderedDict(
        (metric.name, metric.result().numpy()) for metric in metrics)

  hparam_dict = collections.OrderedDict(
      [(name, FLAGS[name].value) for name in HPARAM_FLAGS])

  total_epochs = 0 if FLAGS.total_epochs is None else FLAGS.total_epochs
  training_loop.run(
      iterative_process,
      client_datasets_fn=_get_client_datasets_fn(train_data),
      validation_fn=functools.partial(evaluate_fn, dataset=test_data),
      total_rounds=FLAGS.total_rounds,
      total_epochs=total_epochs,
      experiment_name=FLAGS.experiment_name,
      train_eval_fn=None,
      test_fn=functools.partial(evaluate_fn, dataset=test_data),
      root_output_dir=FLAGS.root_output_dir,
      hparam_dict=hparam_dict,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint,
      rounds_per_train_eval=2000)