def test_raises_no_repeat_and_no_take(self, mock_load_data,
                                      mock_load_word_counts):
  """Checks that epochs=-1 together with batches=-1 is rejected.

  Calling `construct_word_level_datasets` with both
  `client_epochs_per_round=-1` and `max_batches_per_user=-1` must raise a
  `ValueError` whose message mentions `client_epochs_per_round`.
  `mock_load_word_counts` is not configured or inspected by this test.
  """
  # Three placeholder ClientData stand-ins for the loader's return value.
  mock_load_data.return_value = tuple(mock.Mock() for _ in range(3))
  dataset_kwargs = dict(
      vocab_size=100,
      client_batch_size=10,
      client_epochs_per_round=-1,
      max_batches_per_user=-1,
      max_seq_len=20,
      max_training_elements_per_user=128,
      num_validation_examples=500,
      num_oov_buckets=1)
  with self.assertRaisesRegex(ValueError,
                              'Argument client_epochs_per_round is set to -1'):
    stackoverflow_dataset.construct_word_level_datasets(**dataset_kwargs)
def test_preprocess_applied(self, mock_load_data, mock_load_word_counts):
  """Checks which ClientData objects are used and that preprocessing runs.

  Loading from disk is mocked out. The test asserts that:
    * the training ClientData only receives a single `preprocess` call,
    * the validation ClientData is never touched,
    * the test ClientData is turned into one dataset over all clients,
    * data and word counts are each loaded exactly once.
  Correctness of the preprocessing function is covered by other tests.
  """
  if tf.config.list_logical_devices('GPU'):
    self.skipTest('skip GPU test')

  train_data, validation_data, test_data = (
      mock.create_autospec(tff.simulation.ClientData) for _ in range(3))
  test_data.create_tf_dataset_from_all_clients = mock.Mock(
      return_value=mock.Mock())
  mock_load_data.return_value = (train_data, validation_data, test_data)
  # Return a fake word-count dictionary.
  mock_load_word_counts.return_value = collections.OrderedDict(a=1)

  # The datasets themselves are not inspected; unpacking still requires that
  # exactly three values are returned.
  unused_train, unused_validation, unused_test = (
      stackoverflow_dataset.construct_word_level_datasets(
          vocab_size=1000,
          client_batch_size=10,
          client_epochs_per_round=1,
          max_batches_per_user=128,
          max_seq_len=20,
          max_training_elements_per_user=128,
          num_validation_examples=500,
          num_oov_buckets=1))

  mock_load_data.assert_called_once()
  # Validation data is unused; test data is consumed as one all-clients
  # dataset; train data is only preprocessed.
  self.assertEmpty(validation_data.mock_calls)
  expected_test_calls = (
      mock.call.create_tf_dataset_from_all_clients().call_list())
  self.assertEqual(test_data.mock_calls, expected_test_calls)
  expected_train_calls = mock.call.preprocess(mock.ANY).call_list()
  self.assertEqual(train_data.mock_calls, expected_train_calls)
  # Word counts are loaded a single time and shared across datasets.
  mock_load_word_counts.assert_called_once()
def main(argv):
  """Trains a recurrent Stack Overflow word-level model with federated averaging.

  Builds the model, loss and metrics from flags, constructs the word-level
  train/validation/test datasets, optionally wraps aggregation in a
  differentially private process, and hands everything to
  `training_loop.run`.

  Args:
    argv: Arguments remaining after flag parsing; only the program name is
      expected.

  Raises:
    app.UsageError: If any extra command-line arguments are supplied.
    ValueError: If `--noise_multiplier` is set without
      `--uniform_weighting`; differential privacy is only implemented for
      uniform weighting.
  """
  if len(argv) > 1:
    raise app.UsageError('Expected no command-line arguments, '
                         'got: {}'.format(argv))

  tff.backends.native.set_local_execution_context(max_fanout=10)

  model_builder = functools.partial(
      stackoverflow_models.create_recurrent_model,
      vocab_size=FLAGS.vocab_size,
      embedding_size=FLAGS.embedding_size,
      latent_size=FLAGS.latent_size,
      num_layers=FLAGS.num_layers,
      shared_embedding=FLAGS.shared_embedding)

  loss_builder = functools.partial(
      tf.keras.losses.SparseCategoricalCrossentropy, from_logits=True)

  special_tokens = stackoverflow_dataset.get_special_tokens(FLAGS.vocab_size)
  pad_token = special_tokens.pad
  oov_tokens = special_tokens.oov
  eos_token = special_tokens.eos

  def metrics_builder():
    """Returns fresh metric instances (accuracy variants plus counters)."""
    return [
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_with_oov', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov', masked_tokens=[pad_token] + oov_tokens),
        # Notice BOS never appears in ground truth.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_no_oov_or_eos',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
        keras_metrics.NumBatchesCounter(),
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
    ]

  # Pass arguments by keyword so each flag binds to the intended parameter
  # regardless of the function's positional order. The function also accepts
  # `max_batches_per_user` and `num_oov_buckets`, which are left at their
  # defaults here, as in the original call.
  train_dataset, validation_dataset, test_dataset = (
      stackoverflow_dataset.construct_word_level_datasets(
          vocab_size=FLAGS.vocab_size,
          client_batch_size=FLAGS.client_batch_size,
          client_epochs_per_round=FLAGS.client_epochs_per_round,
          max_seq_len=FLAGS.sequence_length,
          max_training_elements_per_user=FLAGS.max_elements_per_user,
          num_validation_examples=FLAGS.num_validation_examples))

  if FLAGS.uniform_weighting:

    def client_weight_fn(local_outputs):
      del local_outputs
      return 1.0

  else:

    def client_weight_fn(local_outputs):
      # Weight each client's update by the number of tokens it processed.
      return tf.cast(tf.squeeze(local_outputs['num_tokens']), tf.float32)

  def model_fn():
    return tff.learning.from_keras_model(
        model_builder(),
        loss_builder(),
        input_spec=validation_dataset.element_spec,
        metrics=metrics_builder())

  if FLAGS.noise_multiplier is not None:
    if not FLAGS.uniform_weighting:
      raise ValueError(
          'Differential privacy is only implemented for uniform weighting.')

    # Uniform weighting gives every client weight 1.0, so the expected total
    # weight per round equals the number of clients per round.
    dp_query = tff.utils.build_dp_query(
        clip=FLAGS.clip,
        noise_multiplier=FLAGS.noise_multiplier,
        expected_total_weight=FLAGS.clients_per_round,
        adaptive_clip_learning_rate=FLAGS.adaptive_clip_learning_rate,
        target_unclipped_quantile=FLAGS.target_unclipped_quantile,
        clipped_count_budget_allocation=FLAGS.clipped_count_budget_allocation,
        expected_clients_per_round=FLAGS.clients_per_round)

    weights_type = tff.learning.framework.weights_type_from_model(model_fn)
    aggregation_process = tff.utils.build_dp_aggregate_process(
        weights_type.trainable, dp_query)
  else:
    aggregation_process = None

  server_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('server')
  client_optimizer_fn = optimizer_utils.create_optimizer_fn_from_flags('client')

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=model_fn,
      server_optimizer_fn=server_optimizer_fn,
      client_weight_fn=client_weight_fn,
      client_optimizer_fn=client_optimizer_fn,
      aggregation_process=aggregation_process)

  client_datasets_fn = training_utils.build_client_datasets_fn(
      train_dataset, FLAGS.clients_per_round)

  evaluate_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      eval_dataset=validation_dataset,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  test_fn = training_utils.build_centralized_evaluate_fn(
      model_builder=model_builder,
      # Use both val and test for symmetry with other experiments, which
      # evaluate on the entire test set.
      eval_dataset=validation_dataset.concatenate(test_dataset),
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  logging.info('Training model:')
  # `Model.summary()` prints to stdout and returns None; route each summary
  # line through the logger instead of logging that None.
  model_builder().summary(print_fn=logging.info)

  hparam_dict = utils_impl.lookup_flag_values(utils_impl.get_hparam_flags())
  training_loop_dict = utils_impl.lookup_flag_values(training_loop_flags)

  training_loop.run(
      iterative_process=iterative_process,
      client_datasets_fn=client_datasets_fn,
      validation_fn=evaluate_fn,
      test_fn=test_fn,
      hparam_dict=hparam_dict,
      **training_loop_dict)