def test_raises_no_repeat_and_no_take(self, mock_load_data,
                                      mock_load_word_counts):
  """A negative client epoch count must be rejected with a ValueError.

  Note: despite the method name, the only invalid argument exercised here is
  `train_client_epochs_per_round=-1`.

  Args:
    mock_load_data: Mock for the data loader (presumably injected by a
      `mock.patch` decorator outside this view). It is configured to return a
      3-tuple of placeholder mocks.
    mock_load_word_counts: Mock for the word-count loader (presumably injected
      the same way). It is not configured by this test.
  """
  gpu_devices = tf.config.list_logical_devices('GPU')
  if gpu_devices:
    self.skipTest('skip GPU test')

  # Three opaque placeholders stand in for the loader's 3-tuple result.
  mock_load_data.return_value = tuple(mock.Mock() for _ in range(3))

  invalid_kwargs = dict(
      vocab_size=100,
      train_client_batch_size=10,
      train_client_epochs_per_round=-1,
      max_sequence_length=20,
      max_elements_per_train_client=128,
      num_oov_buckets=1,
  )
  expected_message = 'client_epochs_per_round must be a positive integer.'
  with self.assertRaisesRegex(ValueError, expected_message):
    stackoverflow_word_prediction.get_federated_datasets(**invalid_kwargs)
def test_preprocess_applied(self, mock_load_data, mock_load_word_counts):
  """Train/test ClientData are preprocessed; validation ClientData is unused.

  Loading from disk is mocked out, so only the wiring of
  `get_federated_datasets` is checked here. The correctness of the
  preprocessing function itself is covered by other tests.

  Args:
    mock_load_data: Mock for the loader returning (train, validation, test)
      ClientData objects (presumably injected by a `mock.patch` decorator).
    mock_load_word_counts: Mock for the word-count loader (presumably injected
      the same way).
  """
  if tf.config.list_logical_devices('GPU'):
    self.skipTest('skip GPU test')

  client_data_cls = tff.simulation.datasets.ClientData
  train_data = mock.create_autospec(client_data_cls)
  validation_data = mock.create_autospec(client_data_cls)
  test_data = mock.create_autospec(client_data_cls)
  mock_load_data.return_value = (train_data, validation_data, test_data)

  # Return a fake word-count dictionary.
  mock_load_word_counts.return_value = collections.OrderedDict(a=1)

  # Unpack into two names so a wrong-arity result still fails the test.
  train_result, test_result = (
      stackoverflow_word_prediction.get_federated_datasets(
          vocab_size=1000,
          train_client_batch_size=10,
          test_client_batch_size=100,
          train_client_epochs_per_round=1,
          test_client_epochs_per_round=1,
          max_sequence_length=20,
          max_elements_per_train_client=128,
          max_elements_per_test_client=-1,
          num_oov_buckets=1))
  del train_result, test_result

  # Data is loaded exactly once, and the validation split is never touched.
  mock_load_data.assert_called_once()
  self.assertEmpty(validation_data.mock_calls)

  # Train and test data each receive exactly one `preprocess(...)` call.
  expected_calls = mock.call.preprocess(mock.ANY).call_list()
  self.assertEqual(train_data.mock_calls, expected_calls)
  self.assertEqual(test_data.mock_calls, expected_calls)

  # Word counts are loaded a single time and applied to each dataset.
  mock_load_word_counts.assert_called_once()
def main(argv):
  """Runs federated training of a recurrent Stack Overflow word-prediction model.

  All configuration is read from `FLAGS`. The model is trained with
  `tff.learning.build_federated_averaging_process`, optionally with a
  differentially private aggregation process, and evaluated centrally.

  Args:
    argv: Command-line arguments. Only the program name is accepted.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
    ValueError: If differential privacy is requested (`noise_multiplier` set)
      without uniform client weighting.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tff.backends.native.set_local_execution_context(max_fanout=10)

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_word_prediction.get_special_tokens(
      FLAGS.vocab_size)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
    ]

  train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
      vocab_size=FLAGS.vocab_size,
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      max_sequence_length=FLAGS.sequence_length,
      max_elements_per_train_client=FLAGS.max_elements_per_user)
  _, validation_dataset, test_dataset = (
      stackoverflow_word_prediction.get_centralized_datasets(
          vocab_size=FLAGS.vocab_size,
          max_sequence_length=FLAGS.sequence_length,
          num_validation_examples=FLAGS.num_validation_examples))

  if FLAGS.uniform_weighting:

    def client_weight_fn(local_outputs):
      del local_outputs  # Every client gets the same weight.
      return 1.0

  else:

    def client_weight_fn(local_outputs):
      # Weight each client by the number of tokens it trained on.
      return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=validation_dataset.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_clients_per_round=FLAGS.clients_per_round)
    weights_type = tff.learning.framework.weights_type_from_model(model_fn)
    aggregation_process = tff.utils.build_dp_aggregate_process(
        weights_type.trainable, dp_query)
  else:
    # None lets build_federated_averaging_process use its default aggregation.
    aggregation_process = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      client_optimizer_fn=client_optimizer_fn,
      aggregation_process=aggregation_process)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  validation_fn = lambda model_weights, round_num: evaluate_fn(model_weights)

  test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  logging.info('Training model:')
  # `Model.summary()` prints and returns None, so logging its return value
  # would record "None". Capture the printed lines and log them instead.
  # Extra keyword arguments are accepted because some Keras versions pass
  # them to `print_fn`.
  summary_lines = []
  model_builder().summary(
      print_fn=lambda line, **unused_kwargs: summary_lines.append(line))
  logging.info('\n'.join(summary_lines))

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      hparam_dict=hparam_dict,
      **training_loop_dict)
def main(argv):
  """Runs federated training of a recurrent Stack Overflow word-prediction model.

  All configuration is read from `FLAGS`. The model is trained with
  `tff.learning.build_federated_averaging_process`, using a Gaussian
  differentially private aggregator (fixed or adaptive clipping) when
  `noise_multiplier` is set and a plain (weighted or unweighted) mean
  otherwise. Hyperparameters are written to
  `<root_output_dir>/results/<experiment_name>/hparams.csv` before training.

  Args:
    argv: Command-line arguments. Only the program name is accepted.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
    ValueError: If DP is requested with non-uniform weighting, with a
      non-positive `noise_multiplier`, with a missing or non-positive `clip`,
      or with a negative `adaptive_clip_learning_rate`.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tff.backends.native.set_local_execution_context(max_fanout=10)

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_word_prediction.get_special_tokens(
      FLAGS.vocab_size)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
    ]

  train_dataset, _ = stackoverflow_word_prediction.get_federated_datasets(
      vocab_size=FLAGS.vocab_size,
      train_client_batch_size=FLAGS.client_batch_size,
      train_client_epochs_per_round=FLAGS.client_epochs_per_round,
      max_sequence_length=FLAGS.sequence_length,
      max_elements_per_train_client=FLAGS.max_elements_per_user)
  _, validation_dataset, test_dataset = (
      stackoverflow_word_prediction.get_centralized_datasets(
          vocab_size=FLAGS.vocab_size,
          max_sequence_length=FLAGS.sequence_length,
          num_validation_examples=FLAGS.num_validation_examples))

  if FLAGS.uniform_weighting:
    client_weighting = tff.learning.ClientWeighting.UNIFORM
  else:
    client_weighting = tff.learning.ClientWeighting.NUM_EXAMPLES

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=validation_dataset.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')
    if FLAGS.noise_multiplier <= 0:
      raise ValueError('noise_multiplier must be positive if DP is enabled.')
    if FLAGS.clip is None or FLAGS.clip <= 0:
      raise ValueError('clip must be positive if DP is enabled.')
    if not FLAGS.adaptive_clip_learning_rate:
      # A zero or unset learning rate selects a fixed clipping norm.
      aggregation_factory = (
          tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
              noise_multiplier=FLAGS.noise_multiplier,
              clients_per_round=FLAGS.clients_per_round,
              clip=FLAGS.clip))
    else:
      if FLAGS.adaptive_clip_learning_rate <= 0:
        raise ValueError(
            'adaptive_clip_learning_rate must be positive if '
            'adaptive clipping is enabled.')
      # FLAGS.clip serves as the starting norm for adaptive clipping.
      aggregation_factory = (
          tff.aggregators.DifferentiallyPrivateFactory.gaussian_adaptive(
              noise_multiplier=FLAGS.noise_multiplier,
              clients_per_round=FLAGS.clients_per_round,
              initial_l2_norm_clip=FLAGS.clip,
              target_unclipped_quantile=FLAGS.target_unclipped_quantile,
              learning_rate=FLAGS.adaptive_clip_learning_rate))
  else:
    # Match the mean aggregator to the chosen client weighting.
    if FLAGS.uniform_weighting:
      aggregation_factory = tff.aggregators.UnweightedMeanFactory()
    else:
      aggregation_factory = tff.aggregators.MeanFactory()

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags(
      'client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weighting=client_weighting,
      client_optimizer_fn=client_optimizer_fn,
      model_update_aggregation_factory=aggregation_factory)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  validation_fn = lambda state, round_num: evaluate_fn(state.model)

  evaluate_test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  test_fn = lambda state: evaluate_test_fn(state.model)

  logging.info('Training model:')
  # `Model.summary()` prints and returns None, so logging its return value
  # would record "None". Capture the printed lines and log them instead.
  # Extra keyword arguments are accepted because some Keras versions pass
  # them to `print_fn`.
  summary_lines = []
  model_builder().summary(
      print_fn=lambda line, **unused_kwargs: summary_lines.append(line))
  logging.info('\n'.join(summary_lines))

  # Log hyperparameters to CSV
  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=validation_fn,
      test_fn=test_fn,
      total_rounds=FLAGS.total_rounds,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      rounds_per_eval=FLAGS.rounds_per_eval,
      rounds_per_checkpoint=FLAGS.rounds_per_checkpoint)